From dcd3db472468672892a9b68c4e63fd01256e19bc Mon Sep 17 00:00:00 2001 From: Salem Derisavi Date: Mon, 13 May 2019 15:48:00 -0400 Subject: [PATCH 001/176] cleanup: removed a piece of code that is redundant now given updates to HalideIR submodule (#3169) --- src/pass/loop_partition.cc | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 04bb9385b156..bcb2608682ee 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -389,20 +389,7 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, for (const auto &kv : partitions) { if (kv.first.second == cond_value) { arith::Interval interval = kv.second.as()->i; - auto intersection = arith::Interval::make_intersection(interval, for_interval); - - // TODO(derisavi): the following if statement needs to be removed as soon as - // TVM uses commit a768f2f0 of HalideIR repo - if (intersection.min.same_as(arith::Interval::pos_inf) || - intersection.max.same_as(arith::Interval::neg_inf)) { - intersection = arith::Interval::nothing(); - } else if (intersection.min.type() == intersection.max.type() && - (intersection.min.type().is_int() || - intersection.min.type().is_uint()) && - can_prove(intersection.min > intersection.max)) { - intersection = arith::Interval::nothing(); - } - + arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval); if (!intersection.is_empty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); From cdaa7bcb72808bd4933b57aa5ff0e976aab89204 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 13 May 2019 14:17:11 -0700 Subject: [PATCH 002/176] add onnx elemwise greater/less (#3186) --- python/tvm/relay/frontend/onnx.py | 19 +++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 2 ++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index eba02e70c865..08a64c37d8df 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -622,6 +622,23 @@ def _impl_v1(cls, inputs, attr, params): extras={'axis':axis})(inputs, {}) #return _op.take(inputs[0], inputs[1], axis) + +class Greater(OnnxOpConverter): + """ Operator logical greater. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.greater(inputs[0], inputs[1]) + + +class Less(OnnxOpConverter): + """ Operator logical less than. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.less(inputs[0], inputs[1]) + + class LRN(OnnxOpConverter): """ Operator converter for Local Response Normalization. """ @@ -836,6 +853,8 @@ def _get_convert_map(opset): 'Selu': Selu.get_converter(opset), 'Elu': Elu.get_converter(opset), 'Exp': Renamer('exp'), + 'Greater': Greater.get_converter(opset), + 'Less': Less.get_converter(opset), 'Log': Renamer('log'), 'Tanh': Renamer('tanh'), 'Pow': Renamer('power'), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f867e73e8c08..77f045aa06cc 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -955,6 +955,8 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): verify_binary_ops("Div", x, y, x / y, broadcast=None) verify_binary_ops("Div", x, z, x / z, broadcast=True) verify_binary_ops("Sum", x, y, x + y, broadcast=None) + verify_binary_ops("Greater", x, y, x > y, broadcast=True) + verify_binary_ops("Less", x, y, x < y, broadcast=True) def test_single_ops(): in_shape = (1, 2, 3, 3) From 6d604c0fe25661d963f267bc6f364269810c169c Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 14 May 2019 05:34:16 -0700 Subject: [PATCH 003/176] [RELAY][PASS] detect depthwise conv2d in mac_count pass (#3083) * check in * use groups * CHECK_EQ * trigger CI * Update mac_count.cc * trigger CI * trigger CI --- src/relay/pass/mac_count.cc | 15 ++++++++---- tests/python/relay/test_pass_mac_count.py | 29 +++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index c9ee4eec0337..3d77fabe6fe9 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,9 @@ #include #include #include +#include #include +#include "pattern_util.h" namespace tvm { namespace relay { @@ -65,7 +67,7 @@ int64_t ConvMacCount(const Call& call_node) { } Array args = call_node->args; CHECK(args.size() == 2) - << "The number of input arguments of a CONV 2D node should be 2."; + << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; @@ -73,18 +75,21 @@ int64_t ConvMacCount(const Call& call_node) { int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); CHECK(C_ind != -1) - << "There is no input channel dimension."; + << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; CHECK(kernel_size.size() == 2) - << "The dimension of the kernel size in Conv 2D should be 2."; + << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; - int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + CHECK_EQ(input_channel % conv_2d_attr->groups, 0) + << "The number of input channels is not divisble by groups."; + count *= input_channel/conv_2d_attr->groups; return count; } diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 5a975fd41364..98ba1ad6325d 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Unit tests for MAC counter.""" +import numpy as np import tvm from tvm import relay @@ -99,7 +100,35 @@ def test_simple_network(): expect_count = 231411712 assert compute_count == expect_count +def test_depthwise_conv2d(): + batch_size = 1 + dshape = (batch_size, 64, 56, 56) + weight_conv = relay.var("weight_depthwiseconv", shape=(64, 1, 3, 3)) + data1 = relay.var("data1", shape=dshape) + data2 = relay.var("data2", shape=dshape) + depthwise_conv2d_1 = relay.nn.conv2d( + data1, + weight_conv, + kernel_size=(3, 3), + padding=(1, 1), + groups=64) + depthwise_conv2d_2 = relay.nn.conv2d( + data2, + weight_conv, + kernel_size=(3, 3), + padding=(1, 1), + groups=64) + add = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + func = relay.Function([data1, data2, weight_conv], + relay.Tuple(tvm.convert([depthwise_conv2d_1, + depthwise_conv2d_2, + add]))) + func = relay.ir_pass.infer_type(func) + compute_count = relay.ir_pass.get_total_mac_number(func) + assert compute_count == 2 * np.prod(dshape) * 3*3 + if __name__ == "__main__": test_conv() test_gemm() test_simple_network() + test_depthwise_conv2d() From ca37de72139d26bbeaffa0620dc43c7c0448480f Mon Sep 17 00:00:00 2001 From: ghostplant Date: Wed, 15 May 2019 01:22:33 +0800 Subject: [PATCH 004/176] Avoid using heavy API to query single attribution (#3179) --- src/codegen/opt/build_cuda_on.cc | 9 +++++---- src/runtime/cuda/cuda_device_api.cc | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc index fda239f0766f..e2a788f1bbd4 100644 --- a/src/codegen/opt/build_cuda_on.cc +++ b/src/codegen/opt/build_cuda_on.cc @@ -84,12 +84,13 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { std::vector compile_params; std::vector param_cstrings{}; nvrtcProgram prog; - cudaDeviceProp device_prop; std::string cc = "30"; - cudaError_t e = cudaGetDeviceProperties(&device_prop, 0); + int major, minor; + cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); + cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); - if (e == cudaSuccess) { - cc = std::to_string(device_prop.major) + std::to_string(device_prop.minor); + if (e1 == cudaSuccess && e2 == cudaSuccess) { + cc = std::to_string(major) + std::to_string(minor); } else { LOG(WARNING) << "cannot detect compute capability from your device, " << "fall back to compute_30."; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index f812156f1999..f5d660c56816 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include "cuda_common.h" @@ -73,9 +74,10 @@ class CUDADeviceAPI final : public DeviceAPI { return; } case kDeviceName: { - cudaDeviceProp props; - CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id)); - *rv = std::string(props.name); + std::string name(256, 0); + CUDA_DRIVER_CALL(cuDeviceGetName(&name[0], name.size(), ctx.device_id)); + name.resize(strlen(name.c_str())); + *rv = std::move(name); return; } case kMaxClockRate: { From ab4f8815496745296b4164fe712dbbd9316e8b6f Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 14 May 2019 22:42:34 -0700 Subject: [PATCH 005/176] [Relay][TensorFlow Frontend] SoftPlus Sqrt (#3187) --- python/tvm/relay/frontend/tensorflow.py | 12 ++++++++++ .../frontend/tensorflow/test_forward.py | 22 ++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 48f78837c525..4bd78b47fe54 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -990,6 +990,16 @@ def _impl(inputs, attr, params): transforms={'axis': ('axis', 1)})([inputs[0]], attr) return _impl +def _softplus(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus + def _impl(inputs, attr, params): + exp_out = AttrCvt('exp')(inputs, attr) + inputs.append(tvm.relay.const(1, attr['T'].name)) + rh = tvm.relay.const(1, attr['T'].name) + add_out = _get_relay_op('add')(exp_out, rh) + return _get_relay_op('log')(add_out) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -1163,9 +1173,11 @@ def _impl(inputs, attr, params): 'Sign' : AttrCvt('sign'), 'Slice' : _slice(), 'Softmax' : _softmax(), + 'Softplus' : _softplus(), 'SpaceToBatchND' : _space_to_batch_nd(), 'Split' : _split(False), 'SplitV' : _split(True), + 'Sqrt' : AttrCvt('sqrt'), 'Square' : _square(), 'Squeeze' : _squeeze(), 'StridedSlice' : _stridedSlice(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 58bbdab02b84..2f1cc2f6c9a4 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1151,7 +1151,6 @@ def test_forward_placeholder(): graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') - print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output)) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) ####################################################################### @@ -1440,22 +1439,37 @@ def test_forward_pow_exp(): compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0') def test_forward_log(): - """test Log """ + """test operator Log """ np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) tf.reset_default_graph() in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") tf.log(in_data, name="log") compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') +def test_forward_softplus(): + """test operator Softplus""" + np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") + tf.nn.softplus(in_data, name="softplus") + compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0') + def test_forward_rsqrt(): """test Rsqrt """ np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) tf.reset_default_graph() in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") tf.rsqrt(in_data, name="rsqrt") - print(tf.get_default_graph().as_graph_def()) compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0') +def test_forward_sqrt(): + """test Sqrt """ + np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") + tf.sqrt(in_data, name="sqrt") + compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0') + ####################################################################### # Mean # ---- @@ -1561,6 +1575,8 @@ def test_forward_reduce_prod(): test_forward_pow_exp() test_forward_sign() test_forward_log() + test_forward_softplus() + test_forward_sqrt() test_forward_rsqrt() test_forward_expand_dims() From f83562773d333dd033c33e78c683da42ca8f64aa Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 15 May 2019 13:34:30 -0700 Subject: [PATCH 006/176] [Datatypes] Custom datatypes (#2900) * Register and use custom datatypes in TVM This patch adds the ability to register and use a custom datatype from Python, using the `register_datatype` call. The datatype can then be passed as the `dtype` parameter using the syntax `dtype="custom[]bitsxlanes"`. * Removes extra file * Register custom datatypes with TVM; specify Cast and Add lowering This commit adds functionality for registering custom datatypes with TVM, and furthermore adding custom lowering functions to lower those custom datatypes. This commit only adds lowering for the Cast and Add ops; more ops will be added soon. Check out some custom datatype samples in my repository of samples: https://github.com/gussmith23/tvm-custom-datatype-samples * Register and lower casts from Python * Formatting * Fix include; was including too much * Add comment * Add DatatypeRegistered * Add storage size field to custom datatypes This field indicates the bitwidth of the opaque block of data into which instances of the datatype will be stored, when TVM compiles. For example, if I create a datatype with a storage size of 16, then - Constants of that datatype will be created as unsigned 16-bit ints - Calls to external functions taking that datatype will pass the data as unsigned 16-bit ints - External functions returning that datatype will be assumed to return unsigned 16-bit ints. * Change how lowering funcs (Cast and other ops) are named in registry tvm.datatypes.lower..cast.. becomes tvm.datatypes.lower..Cast.. And fixes some sloppy code around how the other ops were being formatted. * Update Python register_datatype to accept storage size * Oops, left out one cast->Cast change * Look up storage size when parsing `custom[typename]` When we encounter this type string in Python, it will be parsed into a Halide type object in C++. Some of my original code supported this parsing, but we now have to attach the storage type to the type (by setting the bits field). * Change how external calls for casting/other ops are done Firstly, we now use the storage size of the custom type when determining input/output types; e.g. a cast to a custom type with storage size 16 is seen as a call to an external function returning an opaque uint of size 16. Secondly, write a macro to handle the other ops. Originally I thought I could handle these at runtime, with a single `_register_op` global. I transitioned instead to using individual `_register_Add` etc. calls generated with a macro, but I don't remember why. * When encountering a custom type immediate, generate UIntImm * Translate custom types to LLVM type * Generate correct return type in Casts Originally I was assuming that the result type from casts was always a custom datatype, and so I was making the Call return a UInt type. * Use TVM-idiomatic recursion style in DatatypesLowerer This was actually a bug, I'm pretty sure; we wouldn't have recursed deep on any complex programs. As a result of making this change, I also uncovered another potential bug, where the datatypes lowering pass would attempt to lower a Load of a custom type. By commenting out the `Mutate_` for Load, I was able to stop the error from cropping up, but frankly, I'm not satisfied with the solution; how is it that we are able to run codegen when Loads of custom datatypes are present in the IR? I have not written any code, to my knowledge, that will support this. Perhaps Load does not care about the underlying datatype? * Use CHECK * Add comment about which Mutate_s are needed * Add comments * Add GetCustomDatatypeRegistered as an extern C function * Formatting, comments, casting * Change how datatype string is formatted * Use bits() instead of GetStorageSize Use bits() instead of GetStorageSize * Change comment * Add datatype.py * Change registered function name (datatypes->datatype) * Remove GetStorageSize * Format custom datatypes like any other datatype Specifically, we now print the bits and lanes after the `custom[...]` string. * Correctly implement datatype lowering in Python * Remove unneeded include * Make function naming consistent * Use CHECK instead of internal_assert * Rename macro * Formatting * Rename functions * Implement Cast lowering `_datatype_register_op` is now able to lower both binary ops and Casts. * Formatting * Formatting * Clang format, google style * Fix std::string/extern "C" warnings * Formatting * Formatting * Lower Allocates and Loads during datatype lowering This should ensure that there are no custom datatypes remaining once datatype lowering is done. This will allow us to remove the code in the LLVM codegen which deals with custom datatypes. * Revert additions to codegen_llvm.cc which are now unneeded * Pass cpplint on lower_datatypes.cc * Add clarifying comment * Remove datatype lowering registration funcs from C++ * Add CHECKs * Remove TODO * Remove all references to storage size * Move and rename function * Rename function * Remove done TODOs and other handled comments * Remove irrelevant Load code and comments * Comment out the IR node types I'm not sure about yet * Add bfloat16 datatype unittest * Fix MakeConstScalar MakeConstScalar for a custom datatype will now call out to a function which can be registered on a per-datatype basis. The function will take a double and return the equivalent value in the custom datatype format. Note that these code paths are not actually used or tested at the moment. I have not yet written an example which uses const scalars of a custom datatype. * Formatting * Change pass name * Allow users to register whatever lowering function they want Tianqi pointed out that users should be able to register whatever lowering function they want, and should not be constrained to registering lowering functions which just call out to external libraries. I still provide a function for making lowering functions which call out to external libraries, for convenience. * Add clarifying comment * Remove unneeded comment * Remove unneeded function * Rename file * Undo unnecessary change * Undo unnecessary change * Make naming consistent Rename "datatypes" to "custom datatypes" in most contexts. * Revert an artifact of old code * Fix build warnings, add TODO * Lint * Remove unnecessary use of extern C by separating decl and impl * Error checking * Remove TODO * Missed a name change * Lint * Python lint * Correctly format datatype * Move bfloat16 to 3rdparty * "custom_datatypes" --> "datatype" in most places I left the pass as "LowerCustomDatatypes" to indicate that we're not lowering anything other than custom datatypes. Otherwise, everything else has been changed. * Upgrade datatype unittest I used a float calculator to generate some real testcases for the unittest. * Separate public includes and private implementation Specifically, create cleaner decoupling between datatypes stuff in packed_func and the datatype registry implementation. * Formatting * Limit custom datatype codes to >128 * Add TODOs * Fix comment * Formatting * Clean up datatype unittest * Remove un-exported functions in public headers; UIntImm->FloatImm More places where I accidentally was using implementation-only functions in public headers. Additionally, store custom datatype immediates as FloatImms. A later change will add new lowering logic to lower these FloatImms to UIntImms. Plus formatting change. * Lint * Use FloatImm (not UIntImm) to hold immediates of custom datatypes This change switches from using UIntImm to FloatImm for storing immediates of custom datatypes. The value of the number is stored in a double, which should be enough precision for now, for most custom types we will explore in the immediate future. In line with this change, we change the datatype lowering so that FloatImms are lowered to UInts of the appropriate size. Originally, this was going to be done by allowing the user to register a double->uint__t conversion which would be called at compile time to convert the value from the FloatImm to a UInt and store it in a UIntImm. After discussions with Tianqi, we decided to take the simpler route, and lower FloatImms just as we lower all other ops: by replacing them with Call nodes. In this case, presumably the user will Call out to a conversion function in their datatype library. The justification for this decision is due to the functionality added in #1486. This pull request adds the ability to load LLVM bytecode in at compile time. This applies in our case as follows: 1. The user writes their custom datatype programs and registers their lowering functions in the same way we've been doing it so far. All operations over custom datatypes are lowered to Calls to the datatype library. 2. The user compiles their datatype library to LLVM bytecode. 3. At TVM compile time, the user loads the LLVM bytecode. Depending on how the datatype library is written, Clang should be able to perform constant folding over the custom datatype immediates, even if their conversions are done with calls to the library. Additionally adds test to test the FloatImm codepath. * Re-add a change I removed accidentally during rebase * Cleanup * Remove unnecessary TVM_DLLs * Add custom datatype utilities source file to Go runtime pack * Revert "Remove unnecessary TVM_DLLs" This reverts commit 4b742b99557fd3bf0ce6617f033c8b444b74eda4. * Mark bfloat code as TVM_DLL * Moves custom datatype runtime utilities to c_runtime_api.cc * Revert "Add custom datatype utilities source file to Go runtime pack" This reverts commit aecbcde0b2cc09a2693955b77037fe20f93b5bfd. * Move datatype parsing to its own function * Change comments * Remove unneeded function * Formatting * Formatting * Documentation * Add kCustomBegin, use it for checking for custom types * Documentation * Formatting * Move static definition to implementation * Remove comment * Decide toBeLowered before lowering arguments of Expr In the past, e.g. when lowering custom datatypes for an Add, we would lower a and b first, and then decide whether the resulting new Add needed to be lowered based on the (new) types of a and b. Now, instead, we need to check the types of a and b first (to see if they're custom types), and then lower them (so they'll become non-custom types), and then lower the new Add. * Revert "Move datatype parsing to its own function" This reverts commit d554a5881afcf69af1c070d882a7651022703a09. This broke parsing. Will figure this out later. There isn't a really clean way to separate this out given how the rest of the function is written. * Replace comment * Documentation * Remove comment and TVM_DLL * Better error messages * Remove artifact of rebase * Separate datatypes parsing to its own function * Add \returns * Comment changes; add TODO * Refactor tests --- 3rdparty/HalideIR | 2 +- 3rdparty/bfloat16/bfloat16.cc | 80 +++++++++ CMakeLists.txt | 4 + include/tvm/expr_operator.h | 7 + include/tvm/ir_pass.h | 11 ++ include/tvm/runtime/c_runtime_api.h | 2 + include/tvm/runtime/packed_func.h | 37 +++- python/tvm/__init__.py | 3 +- python/tvm/_ffi/runtime_ctypes.py | 14 +- python/tvm/datatype.py | 146 ++++++++++++++++ src/api/api_pass.cc | 1 + src/codegen/datatype/registry.cc | 108 ++++++++++++ src/codegen/datatype/registry.h | 162 ++++++++++++++++++ src/pass/lower_custom_datatypes.cc | 140 +++++++++++++++ src/runtime/c_runtime_api.cc | 46 +++++ .../test_custom_datatypes_mybfloat16.py | 150 ++++++++++++++++ 16 files changed, 908 insertions(+), 5 deletions(-) create mode 100644 3rdparty/bfloat16/bfloat16.cc create mode 100644 python/tvm/datatype.py create mode 100644 src/codegen/datatype/registry.cc create mode 100644 src/codegen/datatype/registry.h create mode 100644 src/pass/lower_custom_datatypes.cc create mode 100644 tests/python/unittest/test_custom_datatypes_mybfloat16.py diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index a768f2f06279..ec9585a5a5df 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit a768f2f0627917659a4d7167eee3190469b9d164 +Subproject commit ec9585a5a5df3de91e8916ac2d27a4a509eac5fc diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc new file mode 100644 index 000000000000..333b534afc08 --- /dev/null +++ b/3rdparty/bfloat16/bfloat16.cc @@ -0,0 +1,80 @@ +/* + Copyright (c) 2019 by Contributors + \file tvm/src/codegen/custom_datatypes/mybfloat16.cc + \brief Small bfloat16 library for use in unittests + + Code originally from Tensorflow; taken and simplified. Original license: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include +#include +#include + +void FloatToBFloat16(const float* src, uint16_t* dst, size_t size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p += 2, q++, size--) { + *q = p[0]; + } +#else + for (; size != 0; p += 2, q++, size--) { + *q = p[1]; + } +#endif +} + +void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p++, q += 2, size--) { + q[0] = *p; + q[1] = 0; + } +#else + for (; size != 0; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +#endif +} + +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, + size_t size) { + float a_f, b_f; + BFloat16ToFloat(a, &a_f, 1); + BFloat16ToFloat(b, &b_f, 1); + float out_f = a_f + b_f; + FloatToBFloat16(&out_f, dst, 1); +} + +extern "C" { +TVM_DLL TVM_DLL uint16_t FloatToBFloat16_wrapper(float in) { + uint16_t out; + FloatToBFloat16(&in, &out, 1); + return out; +} + +TVM_DLL float BFloat16ToFloat_wrapper(uint16_t in) { + float out; + BFloat16ToFloat(&in, &out, 1); + return out; +} + +TVM_DLL uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b) { + uint16_t out; + BFloat16Add(&a, &b, &out, 1); + return out; +} +} diff --git a/CMakeLists.txt b/CMakeLists.txt index dceb9f46568e..9f8bbbe24568 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,6 +123,8 @@ file(GLOB_RECURSE RELAY_SRCS ) list(APPEND COMPILER_SRCS ${RELAY_SRCS}) +file(GLOB DATATYPE_SRCS src/codegen/datatype/*.cc) +list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) @@ -152,6 +154,8 @@ if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) endif() +list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc) + if(USE_RPC) message(STATUS "Build with RPC support...") file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 4ef3effaf251..2e1348e00470 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -33,6 +33,7 @@ #include "ir.h" namespace tvm { + /*! * \brief Make a const value with certain data type. * \param t The target type. @@ -551,6 +552,12 @@ inline Expr MakeConstScalar(Type t, ValueType value) { if (t.is_int()) return ir::IntImm::make(t, static_cast(value)); if (t.is_uint()) return ir::UIntImm::make(t, static_cast(value)); if (t.is_float()) return ir::FloatImm::make(t, static_cast(value)); + // For now, we store const scalar values of custom datatypes within doubles; later, during the + // datatypes lowering pass, we will lower the value to its true representation in the format + // specified by the datatype. + // TODO(gus) when do we need to start worrying about doubles not being precise enough? + if (static_cast(t.code()) >= static_cast(kCustomBegin)) + return ir::FloatImm::make(t, static_cast(value)); LOG(FATAL) << "cannot make const for type " << t; return Expr(); } diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 20b56e0676eb..5ef4dc4ed9d7 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -500,6 +500,17 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); +/*! + * \brief Lower custom datatypes. + * + * See tvm::datatypes::Registry for more information on adding custom datatypes. + * + * \param f The device function to be lowered. + * \param target The target device. + * \return Transformed function. + */ +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); + /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f992e87ad100..ee3542f90255 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -114,6 +114,8 @@ typedef enum { // The following section of code is used for non-reserved types. kExtReserveEnd = 64U, kExtEnd = 128U, + // The rest of the space is used for custom, user-supplied datatypes + kCustomBegin = 128U, } TVMTypeCode; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 9fcefcbbe4b1..82b3dd469541 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -60,6 +60,29 @@ namespace tvm { class Integer; namespace runtime { + +/*! + * \brief Runtime utility for getting custom type name from code + * \param type_code Custom type code + * \return Custom type name + */ +TVM_DLL std::string GetCustomTypeName(uint8_t type_code); + +/*! + * \brief Runtime utility for checking whether custom type is registered + * \param type_code Custom type code + * \return Bool representing whether type is registered + */ +TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); + +/*! + * \brief Runtime utility for parsing string of the form "custom[]" + * \param s String to parse + * \param scan pointer to parsing pointer, which is scanning across s + * \return type code of custom type parsed + */ +TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); + // forward declarations class TVMArgs; class TVMArgValue; @@ -939,7 +962,11 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } - os << TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + os << "custom[" << GetCustomTypeName(t.code) << "]"; + } else { + os << TypeCode2Str(t.code); + } if (t.code == kHandle) return os; os << static_cast(t.bits); if (t.lanes != 1) { @@ -960,7 +987,11 @@ inline std::string TVMType2String(TVMType t) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { return "bool"; } - repr += TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + repr += "custom[" + GetCustomTypeName(t.code) + "]"; + } else { + repr += TypeCode2Str(t.code); + } if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); if (t.lanes != 1) { @@ -994,6 +1025,8 @@ inline TVMType String2TVMType(std::string s) { t.bits = 1; t.lanes = 1; return t; + } else if (s.substr(0, 6) == "custom") { + t.code = ParseCustomDatatype(s, &scan); } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ce6f0602a572..5765eed0ad8b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -38,12 +38,13 @@ from . import hybrid from . import testing from . import error +from . import datatype from . import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, opengl, ext_dev -from ._ffi.runtime_ctypes import TypeCode +from ._ffi.runtime_ctypes import TypeCode, TVMType from ._ffi.ndarray import TVMContext from ._ffi.function import Function from ._ffi.base import TVMError, __version__ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 4ede33a63936..72cff1a10ead 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -91,6 +91,13 @@ def __init__(self, type_str): self.type_code = 4 bits = 64 head = "" + elif head.startswith("custom"): + low, high = head.find('['), head.find(']') + if not low or not high or low >= high: + raise ValueError("Badly formatted custom type string %s" % type_str) + type_name = head[low + 1:high] + self.type_code = _api_internal._datatype_get_type_code(type_name) + head = head[high+1:] else: raise ValueError("Do not know how to handle type %s" % type_str) bits = int(head) if head else bits @@ -100,7 +107,12 @@ def __init__(self, type_str): def __repr__(self): if self.bits == 1 and self.lanes == 1: return "bool" - x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) + if self.type_code in TVMType.CODE2STR: + type_name = TVMType.CODE2STR[self.type_code] + else: + type_name = "custom[%s]" % \ + _api_internal._datatype_get_type_name(self.type_code) + x = "%s%d" % (type_name, self.bits) if self.lanes != 1: x += "x%d" % self.lanes return x diff --git a/python/tvm/datatype.py b/python/tvm/datatype.py new file mode 100644 index 000000000000..df3e3a62a510 --- /dev/null +++ b/python/tvm/datatype.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Custom datatype functionality""" +from __future__ import absolute_import as _abs + +from ._ffi.function import register_func as _register_func +from . import make as _make +from .api import convert +from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from ._ffi.runtime_ctypes import TVMType as _TVMType +from . import _api_internal + + +def register(type_name, type_code): + """Register a custom datatype with the given type name and type code + Currently, the type code is manually allocated by the user, and the + user must ensure that no two custom types share the same code. + Generally, this should be straightforward, as the user will be + manually registering all of their custom types. + + Parameters + ---------- + type_name : str + The name of the custom datatype + + type_code : int + The type's code, which should be >= kCustomBegin + """ + _api_internal._datatype_register(type_name, type_code) + + +def get_type_name(type_code): + """Get the type name from the type code + + Parameters + ---------- + type_code : int + The type code + """ + return _api_internal._datatype_get_type_name(type_code) + + +def get_type_code(type_name): + """Get the type code from the type name + + Parameters + ---------- + type_name : str + The type name + """ + return _api_internal._datatype_get_type_code(type_name) + + +def get_type_registered(type_code): + """Get a boolean representing whether the type is registered + + Parameters + ---------- + type_code: int + The type code + """ + return _api_internal._datatype_get_type_registered(type_code) + + +def register_op(lower_func, op_name, target, type_name, src_type_name=None): + """Register an external function which computes the given op. + + Currently, this will only work with Casts and binary expressions + whose arguments are named `a` and `b`. + TODO(gus) figure out what other special cases must be handled by + looking through expr.py. + + Parameters + ---------- + lower_func : function + The lowering function to call. See create_lower_func. + + op_name : str + The name of the operation which the function computes, given by its + Halide::Internal class name (e.g. Add, LE, Cast). + + target : str + The name of codegen target. + + type_name : str + The name of the custom datatype, e.g. posit (but not custom[posit]8). + + src_type_name : str + If op_name is "Cast", then this should be set to the source datatype of + the argument to the Cast. If op_name is not "Cast", this is unused. + """ + + if op_name == "Cast": + assert src_type_name is not None + lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + + type_name + "." + src_type_name + else: + lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + + type_name + _register_func(lower_func_name, lower_func) + + +def create_lower_func(extern_func_name): + """Returns a function which lowers an operation to a function call. + + Parameters + ---------- + extern_func_name : str + The name of the extern "C" function to lower to + """ + + def lower(op): + """ + Takes an op---either a Cast or a binary op (e.g. an Add) and returns a + call to the specified external function, passing the op's argument + (Cast) or arguments (a binary op). The return type of the call depends + on the type of the op: if it is a custom type, then a uint of the same + width as the custom type is returned. Otherwise, the type is + unchanged.""" + dtype = op.dtype + t = _TVMType(dtype) + if get_type_registered(t.type_code): + dtype = "uint" + str(t.bits) + if t.lanes > 1: + dtype += "x" + str(t.lanes) + if isinstance(op, (_Cast, _FloatImm)): + return _make.Call(dtype, extern_func_name, convert([op.value]), + _Call.Extern, None, 0) + return _make.Call(dtype, extern_func_name, convert([op.a, op.b]), + _Call.Extern, None, 0) + + return lower diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 6195aac1b93f..d6c92aee94d1 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -151,6 +151,7 @@ REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); +REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); REGISTER_PASS(VerifyMemory); diff --git a/src/codegen/datatype/registry.cc b/src/codegen/datatype/registry.cc new file mode 100644 index 000000000000..28cc58204e8d --- /dev/null +++ b/src/codegen/datatype/registry.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "registry.h" +#include + +namespace tvm { +namespace datatype { + +TVM_REGISTER_GLOBAL("_datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { + datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = datatype::Registry::Global()->GetTypeCode(args[0]); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeName(args[0].operator int()); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_registered").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); +}); + +Registry* Registry::Global() { + static Registry inst; + return &inst; +} + +void Registry::Register(const std::string& type_name, uint8_t type_code) { + CHECK(type_code >= kCustomBegin) << "Please choose a type code >= kCustomBegin for custom types"; + code_to_name_[type_code] = type_name; + name_to_code_[type_name] = type_code; +} + +uint8_t Registry::GetTypeCode(const std::string& type_name) { + CHECK(name_to_code_.find(type_name) != name_to_code_.end()) + << "Type name " << type_name << " not registered"; + return name_to_code_[type_name]; +} + +std::string Registry::GetTypeName(uint8_t type_code) { + CHECK(code_to_name_.find(type_code) != code_to_name_.end()) + << "Type code " << static_cast(type_code) << " not registered"; + return code_to_name_[type_code]; +} + +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target << "."; + ss << "Cast" + << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + ss << datatype::Registry::Global()->GetTypeName(type_code); + } else { + ss << runtime::TypeCode2Str(type_code); + } + + ss << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { + ss << datatype::Registry::Global()->GetTypeName(src_type_code); + } else { + ss << runtime::TypeCode2Str(src_type_code); + } + + return runtime::Registry::Get(ss.str()); +} + +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target; + ss << ".FloatImm."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + return runtime::Registry::Get(ss.str()); +} + +uint64_t ConvertConstScalar(uint8_t type_code, double value) { + std::ostringstream ss; + ss << "tvm.datatype.convertconstscalar.float."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + auto make_const_scalar_func = runtime::Registry::Get(ss.str()); + return (*make_const_scalar_func)(value).operator uint64_t(); +} + +} // namespace datatype +} // namespace tvm diff --git a/src/codegen/datatype/registry.h b/src/codegen/datatype/registry.h new file mode 100644 index 000000000000..d2e615765a18 --- /dev/null +++ b/src/codegen/datatype/registry.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_CODEGEN_DATATYPE_REGISTRY_H_ +#define TVM_CODEGEN_DATATYPE_REGISTRY_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace datatype { + +/*! + * \brief Registry for custom datatypes. + * + * Adding custom datatypes currently requires two steps: + * 1. Register the datatype with the registry via a call to + * datatype::Registry::Register. This can also be done in Python + * directly---see the TVM globals registered in the corresponding .cc file. + * Currently, user should manually choose a type name and a type code, + * ensuring that neither conflict with existing types. + * 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to + * lower the custom datatype. In general, these will look like: + * For Casts: tvm.datatype.lower..Cast.. + * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from + * float to myfloat. + * For other ops: tvm.datatype.lower... + * Examples: tvm.datatype.lower.llvm.Add.myfloat + * tvm.datatype.lower.llvm.FloatImm.posit + */ +class Registry { + public: + /*! + * \brief Get the global custom datatype registry singleton + */ + static Registry* Global(); + + /*! + * \brief Register custom datatype + * Register a custom datatype with the given type name and type code. Currently, the type code is + * manually allocated by the user, and the user must ensure that no two custom types share the + * same code. Generally, this should be straightforward, as the user will be manually registering + * all of their custom types. + * \param type_name The name of the type, e.g. "bfloat" + * \param type_code The type code, which should be greater than TVMTypeCode::kExtEnd + */ + void Register(const std::string& type_name, uint8_t type_code); + + /*! + * \brief Get type code from type name + * \param type_name The type name + * \return The type code + */ + uint8_t GetTypeCode(const std::string &type_name); + + /*! + * \brief Get type name from type code + * \param type_code The type code + * \return The type name + */ + std::string GetTypeName(uint8_t type_code); + + /*! + * \brief Get bool representing whether type is registered, given the type code + * \param type_code The type code + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(uint8_t type_code) { + return code_to_name_.find(type_code) != code_to_name_.end(); + } + + /*! + * \brief Get bool representing whether type is registered, given the type name + * \param type_name The type name + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(std::string type_name) { + return name_to_code_.find(type_name) != name_to_code_.end(); + } + + private: + // TODO(gus) is there a typedef for the code? + std::unordered_map code_to_name_; + std::unordered_map name_to_code_; +}; + +/*! + * \brief Convert scalar value to a custom datatype format + * \param type_code The custom datatype to convert to, specified by type code + * \param value The floating point value to convert + * \return The value, encoded in the bits of a uint64_t + */ +uint64_t ConvertConstScalar(uint8_t type_code, double value); + +/*! + * \brief Get lowering function for Cast ops + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype being cast to + * \param src_type_code The datatype being cast from + * \return Lowering function for Cast ops for the provided target, type, and source type + */ +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code); + +/*! + * \brief Get lowering function for FloatImms + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the FloatImm + * \return Lowering function for FloatImms for the provided target and type + */ +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code); + +/*! + * \brief Get lowering function for other ops + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the op + * \return Lowering function for other ops for the provided target and type + */ +#define DEFINE_GET_LOWER_FUNC_(OP) \ + inline const runtime::PackedFunc* Get##OP##LowerFunc(const std::string& target, \ + uint8_t type_code) { \ + return runtime::Registry::Get("tvm.datatype.lower." + target + "." #OP "." + \ + datatype::Registry::Global()->GetTypeName(type_code)); \ + } + +DEFINE_GET_LOWER_FUNC_(Add) +DEFINE_GET_LOWER_FUNC_(Sub) +DEFINE_GET_LOWER_FUNC_(Mul) +DEFINE_GET_LOWER_FUNC_(Div) +DEFINE_GET_LOWER_FUNC_(Mod) +DEFINE_GET_LOWER_FUNC_(Min) +DEFINE_GET_LOWER_FUNC_(Max) +DEFINE_GET_LOWER_FUNC_(EQ) +DEFINE_GET_LOWER_FUNC_(NE) +DEFINE_GET_LOWER_FUNC_(LT) +DEFINE_GET_LOWER_FUNC_(LE) +DEFINE_GET_LOWER_FUNC_(GT) +DEFINE_GET_LOWER_FUNC_(GE) +// Later changes may need to add more lowering functions as we support workloads with more ops. + +} // namespace datatype +} // namespace tvm + +#endif // TVM_CODEGEN_DATATYPE_REGISTRY_H_ diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc new file mode 100644 index 000000000000..7598ef49eee0 --- /dev/null +++ b/src/pass/lower_custom_datatypes.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/src/pass/lower_custom_datatypes.cc + * \brief Pass for lowering custom datatypes + */ + +#include +#include +#include +#include "../codegen/datatype/registry.h" + +namespace tvm { +namespace ir { + +/*! + * \brief Helper mutator to implement lowering of custom datatypes. + * + * Lowering datatypes works as follows: for every expression containing a custom + * datatype, we search for a global (registered by the implementer of the custom + * datatype) for lowering this type of expression, and uses it to lower the + * expression. + */ +class CustomDatatypesLowerer : public IRMutator { + public: + explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} + + inline Expr Mutate_(const Cast* op, const Expr& e) final { + auto type_code = op->type.code(); + auto src_type_code = op->value.type().code(); + // If either datatype is a registered custom datatype, we must lower. + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || + datatype::Registry::Global()->GetTypeRegistered(src_type_code); + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + if (toBeLowered) { + auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); + CHECK(lower) << "Cast lowering function for target " << target_ << " destination type " + << static_cast(type_code) << " source type " + << static_cast(src_type_code) << " not found"; + return (*lower)(expr); + } + return expr; + } + + inline Expr Mutate_(const FloatImm* imm, const Expr& e) final { + auto type_code = imm->type.code(); + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); + CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " + << static_cast(type_code) << " not found"; + return (*lower)(e); + } + return e; + } + + inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final { + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->type.code()); + Stmt stmt = IRMutator::Mutate_(allocate, s); + allocate = stmt.as(); + + if (toBeLowered) { + auto new_allocate_type = UInt(allocate->type.bits(), allocate->type.lanes()); + return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body, allocate->new_expr, + allocate->free_function); + } + return stmt; + } + + inline Expr Mutate_(const Load* load, const Expr& e) final { + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->type.code()); + Expr expr = IRMutator::Mutate_(load, e); + load = expr.as(); + if (toBeLowered) { + auto new_load_type = UInt(load->type.bits()); + return Load::make(new_load_type, load->buffer_var, load->index, load->predicate); + } + return expr; + } + +#define DEFINE_MUTATE__(OP) \ + inline Expr Mutate_(const OP* op, const Expr& e) final { \ + auto type_code = op->type.code(); \ + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ + Expr expr = IRMutator::Mutate_(op, e); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ + } + + DEFINE_MUTATE__(Add) + DEFINE_MUTATE__(Sub) + DEFINE_MUTATE__(Mul) + DEFINE_MUTATE__(Div) + DEFINE_MUTATE__(Mod) + DEFINE_MUTATE__(Min) + DEFINE_MUTATE__(Max) + DEFINE_MUTATE__(EQ) + DEFINE_MUTATE__(NE) + DEFINE_MUTATE__(LT) + DEFINE_MUTATE__(LE) + DEFINE_MUTATE__(GT) + DEFINE_MUTATE__(GE) + // Later changes may need to add more mutate functions as we support workloads with more ops. + + private: + std::string target_; +}; + +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { + auto n = make_node(*f.operator->()); + n->body = CustomDatatypesLowerer(target).Mutate(n->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 59cdb7f0a467..20793b4618b3 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -45,6 +45,52 @@ namespace tvm { namespace runtime { +std::string GetCustomTypeName(uint8_t type_code) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_name"); + CHECK(f) << "Function _datatype_get_type_name not found"; + return (*f)(type_code).operator std::string(); +} + +uint8_t GetCustomTypeCode(const std::string& type_name) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_code"); + CHECK(f) << "Function _datatype_get_type_code not found"; + return (*f)(type_name).operator int(); +} + +bool GetCustomTypeRegistered(uint8_t type_code) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_registered"); + CHECK(f) << "Function _datatype_get_type_registered not found"; + return (*f)(type_code).operator bool(); +} + +uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { + CHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string"; + + auto tmp = s.c_str(); + + CHECK(s.c_str() == tmp); + *scan = s.c_str() + 6; + CHECK(s.c_str() == tmp); + if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += 1; + CHECK(s.c_str() == tmp); + size_t custom_name_len = 0; + CHECK(s.c_str() == tmp); + while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') + ++custom_name_len; + CHECK(s.c_str() == tmp); + if (*(*scan + custom_name_len) != ']') + LOG(FATAL) << "expected closing brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += custom_name_len + 1; + CHECK(s.c_str() == tmp); + + auto type_name = s.substr(7, custom_name_len); + CHECK(s.c_str() == tmp); + return GetCustomTypeCode(type_name); +} + class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; diff --git a/tests/python/unittest/test_custom_datatypes_mybfloat16.py b/tests/python/unittest/test_custom_datatypes_mybfloat16.py new file mode 100644 index 000000000000..99c6cf5f268b --- /dev/null +++ b/tests/python/unittest/test_custom_datatypes_mybfloat16.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from ctypes import * +import topi +import tvm.ir_pass as ir_pass +import numpy as np + +tgt = "llvm" + + +def setup(): + # You must first load the library containing the datatype implementation. + # In this case, we have built the test functions used below right into TVM. + # CDLL("libmybfloat16.so", RTLD_GLOBAL) + + tvm.datatype.register("bfloat", 129) + + tvm.datatype.register_op( + tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast", + "llvm", "bfloat", "float") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast", + "llvm", "float", "bfloat") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", + "bfloat") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm", + "llvm", "bfloat") + +def lower_datatypes_and_build(schedule, args): + """Create schedule and lower, manually lowering datatypes. + + Once datatype lowering is integrated directly into TVM's lower/build + process, we won't need to do this manually. + TODO(gus) integrate datatype lowering into build process; change this test""" + flist = tvm.lower(schedule, args) + flist = [flist] + flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist] + return tvm.build(flist[0], target=tgt) + +def test_bfloat_add_and_cast_1(): + X = tvm.placeholder((3, ), name="X") + Y = tvm.placeholder((3, ), name="Y") + Z = topi.cast( + topi.cast(X, dtype="custom[bfloat]16") + + topi.cast(Y, dtype="custom[bfloat]16"), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Y,Z]) + + ctx = tvm.context(tgt, 0) + + # Used float32 calculator at http://www.weitz.de/ieee/. Generated float32s + # with at most 7-bit mantissas which, when added, produce a result with at + # most 7-bit mantissas. This is to ensure there are no errors due to + # float32->bfloat16 conversions. + x = tvm.nd.array( + np.array([4.4103796E-32, 14942208.0, 1.78125]).astype("float32"), + ctx=ctx) + y = tvm.nd.array( + np.array([-3.330669E-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx) + z_expected = np.array([-3.330669E-14, 34603008.0, + 4.03125]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, y, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +def test_bfloat_add_and_cast_2(): + X = tvm.placeholder((3, ), name="X") + Y = tvm.placeholder((3, ), name="Y") + Z = topi.cast( + topi.cast(X, dtype="custom[bfloat]16") + + topi.cast(Y, dtype="custom[bfloat]16"), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Y,Z]) + + ctx = tvm.context(tgt, 0) + + # Used float32 calculator at http://www.weitz.de/ieee/. Generated + # unconstrained float32s for the operands and copied them in to x and y. + # Then, to simulate float32->bfloat16 conversion implemented by the mybfloat + # library, I cut off all but 7 bits of the mantissa. I then added the + # numbers. To simulate bfloat16 add implemented in mybfloat, I cut off all + # but 7 bits of the result's mantissa. I then copied that value into + # z_expected. + x = tvm.nd.array( + np.array([1.2348297, -1.0298302E25, 1.2034023E-30]).astype("float32"), + ctx=ctx) + y = tvm.nd.array( + np.array([-2.4992788, -9.888288E19, 9.342338E-29]).astype("float32"), + ctx=ctx) + z_expected = np.array([-1.25, -1.027587E25, + 9.426888E-29]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, y, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +def test_bfloat_add_and_cast_FloatImm(): + X = tvm.placeholder((3, ), name="X") + Z = topi.cast( + topi.add( + topi.cast(X, dtype="custom[bfloat]16"), + tvm.expr.FloatImm("custom[bfloat]16", 1.5)), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Z]) + + ctx = tvm.context(tgt, 0) + + x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx) + z_expected = np.array([1.5, 2.5, 3.0]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +if __name__ == "__main__": + setup() + test_bfloat_add_and_cast_1() + test_bfloat_add_and_cast_2() + test_bfloat_add_and_cast_FloatImm() From cb91d7e225577ecb14238e24b4fb2db93fee94c5 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Wed, 15 May 2019 17:28:18 -0700 Subject: [PATCH 007/176] [Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174) --- python/tvm/relay/__init__.py | 2 +- python/tvm/relay/_build_module.py | 21 + .../relay/backend/graph_runtime_codegen.py | 18 +- python/tvm/relay/build_module.py | 465 ++++++++---------- python/tvm/relay/quantize/quantize.py | 79 ++- src/codegen/build_module.cc | 2 +- src/relay/backend/build_module.cc | 166 +++---- src/relay/backend/graph_runtime_codegen.cc | 60 +-- tests/cpp/relay_build_module_test.cc | 9 +- .../nnvm_to_relay/test_alter_conv2d.py | 11 +- tests/python/relay/test_cpp_build_module.py | 148 +++--- tests/python/relay/test_pass_annotation.py | 2 +- tests/python/relay/test_pass_quantize.py | 92 ++-- 13 files changed, 534 insertions(+), 541 deletions(-) create mode 100644 python/tvm/relay/_build_module.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 6201681e0294..1f1e4a683ead 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -25,7 +25,7 @@ from . import module from . import adt from . import ir_pass -from .build_module import build, build_config, create_executor, optimize +from .build_module import build, build_config, create_executor from . import prelude from . import parser from . import debug diff --git a/python/tvm/relay/_build_module.py b/python/tvm/relay/_build_module.py new file mode 100644 index 000000000000..bdbcbefff523 --- /dev/null +++ b/python/tvm/relay/_build_module.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface for building Relay functions exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay.build_module", __name__) diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index ea1846b93beb..cf31e9cff833 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -36,12 +36,9 @@ from __future__ import absolute_import from tvm.ndarray import empty -from tvm._ffi.function import _init_api - from tvm.relay import build_module from tvm import target as _target - -_init_api("tvm.relay.build_module") +from tvm import expr as _expr class GraphRuntimeCodegen(object): """The compiler from Relay to the TVM runtime system.""" @@ -57,17 +54,14 @@ def __init__(self, mod, target): self._setup(mod, target) def _setup(self, mod, target): - tgts = [] + tgts = {} if isinstance(target, dict): - for kv in target.items(): - tgts.append(kv[0]) - if isinstance(kv[1], (str, _target.Target)): - tgts.append(str(kv[1])) - else: + for dev, tgt in target.items(): + if not isinstance(tgt, (str, _target.Target)): raise Exception("Unknown target type") + tgts[dev] = _target.create(tgt) elif isinstance(target, (str, _target.Target)): - tgts.append("0") - tgts.append(str(target)) + tgts[_expr.IntImm("int32", 0)] = _target.create(target) self._init(mod, tgts) def codegen(self, func): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c8b69e011543..d0ad78fee67f 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -18,32 +18,19 @@ Construct the necessary state for the TVM graph runtime from a Relay expression. """ -import warnings +import numpy as np from tvm._ffi.runtime_ctypes import TVMContext -from ..build_module import build as _tvm_build_module +from tvm import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt +from . import _build_module from . import ir_pass -from . import expr as _expr from . import ty as _ty +from . import expr as _expr from .backend import interpreter as _interpreter -from .backend import graph_runtime_codegen as _graph_gen from .backend.vm import VMExecutor -# List of optimization pass and level when switch on -OPT_PASS_LEVEL = { - "SimplifyInference": 0, - "OpFusion": 1, - "FoldConstant": 2, - "CombineParallelConv2D": 3, - "FoldScaleAxis": 3, - "AlterOpLayout": 3, - "CanonicalizeOps": 3, - "EliminateCommonSubexpr": 3, -} - - class BuildConfig(object): """Configuration scope to set a build config option. @@ -56,6 +43,7 @@ class BuildConfig(object): defaults = { "opt_level": 2, "add_pass": None, + "disable_pass": None, "fallback_device": None, } @@ -85,23 +73,6 @@ def __exit__(self, ptype, value, trace): assert self._old_scope BuildConfig.current = self._old_scope - def pass_enabled(self, pass_name): - """Get whether pass is enabled. - - Parameters - ---------- - pass_name : str - The optimization pass name - - Returns - ------- - enabled : bool - Whether pass is enabled. - """ - if self.add_pass and pass_name in self.add_pass: - return True - return self.opt_level >= OPT_PASS_LEVEL[pass_name] - BuildConfig.current = BuildConfig() @@ -117,6 +88,9 @@ def build_config(**kwargs): add_pass: set of str Optimization pass to be added regardless of optimization level. + disable_pass: set of str + Optimization pass to be disabled during optimization. + fallback_device : str or tvm.TVMContext The fallback device. It is also used as the default device for operators without specified device during heterogeneous execution. @@ -129,108 +103,203 @@ def build_config(**kwargs): return BuildConfig(**kwargs) -def _bind_params_by_name(func, params): - """Bind parameters of function by its name.""" - name_dict = {} - for arg in func.params: - name = arg.name_hint - if name in name_dict: - name_dict[name] = None - else: - name_dict[name] = arg - bind_dict = {} - for k, v in params.items(): - if k not in name_dict: - continue - arg = name_dict[k] - if arg is None: - raise ValueError("Multiple args in the function have name %s" % k) - bind_dict[arg] = _expr.const(v) - return _expr.bind(func, bind_dict) - - -def optimize(func, target=None, params=None): - """Perform target invariant optimizations. - - Parameters - ---------- - func : tvm.relay.Function - The input to optimization. +def _update_target(target): + target = target if target else _target.current_target() + if target is None: + raise ValueError("Target is not set in env or passed as argument.") - target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]] - The optimization target. For heterogeneous compilation, it is a - dictionary mapping device type to compilation target. For homogeneous - compilation, it is a build target. + tgts = {} + if isinstance(target, (str, _target.Target)): + dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type) + tgts[dev_type] = _target.create(target) + elif isinstance(target, dict): + for dev, tgt in target.items(): + dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type) + tgts[dev_type] = _target.create(tgt) + else: + raise TypeError("target is expected to be str or " + + "tvm.target.Target, but received " + + "{}".format(type(target))) + return tgts - params : Optional[Dict[str, tvm.nd.NDArray]] - Input parameters to the graph that do not change - during inference time. used for constant folding. - Returns - ------- - opt_func : tvm.relay.Function - The optimized version of the function. +class BuildModule(object): + """Build a Relay function to run on TVM graph runtime. This class is used + to expose the `RelayBuildModule` APIs implemented in C++. """ - cfg = BuildConfig.current - - # bind expressions - if params: - func = _bind_params_by_name(func, params) - - if cfg.pass_enabled("SimplifyInference"): - func = ir_pass.infer_type(func) - func = ir_pass.simplify_inference(func) - - if cfg.pass_enabled("EliminateCommonSubexpr"): - def fskip(expr): - if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \ - expr.attrs.dtype == 'int32': - return True - return False - - func = ir_pass.infer_type(func) - func = ir_pass.eliminate_common_subexpr(func, fskip) - - if cfg.pass_enabled("CombineParallelConv2D"): - func = ir_pass.infer_type(func) - func = ir_pass.combine_parallel_conv2d(func) - - # The constant folding pass is necessary because FoldScaleAxis pass needs - # to check the constantness and positiveness of scales. - if cfg.pass_enabled("FoldConstant"): - func = ir_pass.fold_constant(func) - - if cfg.pass_enabled("FoldScaleAxis"): - func = ir_pass.infer_type(func) - func = ir_pass.backward_fold_scale_axis(func) - func = ir_pass.infer_type(func) - func = ir_pass.forward_fold_scale_axis(func) - func = ir_pass.fold_constant(func) - - if cfg.pass_enabled("CanonicalizeOps"): - func = ir_pass.infer_type(func) - func = ir_pass.canonicalize_ops(func) - - # FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for - # now. We probably need to pass target to this pass as well. Fix it in - # a followup PR. - if cfg.pass_enabled("AlterOpLayout"): - if isinstance(target, _target.Target): - func = ir_pass.infer_type(func) - with target: - func = ir_pass.alter_op_layout(func) - elif isinstance(target, dict): - warnings.warn("AlterOpLayout pass is not enabled for heterogeneous" - " execution yet.") - - if cfg.pass_enabled("FoldConstant"): - func = ir_pass.fold_constant(func) - - return func + def __init__(self): + self.mod = _build_module._BuildModule() + self._get_graph_json = self.mod["get_graph_json"] + self._get_module = self.mod["get_module"] + self._build = self.mod["build"] + self._add_pass = self.mod["add_pass"] + self._disable_pass = self.mod["disable_pass"] + self._set_opt_level = self.mod["set_opt_level"] + self._set_fallback_device = self.mod["set_fallback_device"] + self._set_params_func = self.mod["set_params"] + self._get_params_func = self.mod["get_params"] + + def build(self, func, target=None, target_host=None, params=None): + """ + Parameters + ---------- + func: relay.Function + The function to build. + + target : str, :any:`tvm.target.Target`, or dict of str(i.e. + device/context name) to str/tvm.target.Target, optional + For heterogeneous compilation, it is a dictionary indicating context + to target mapping. For homogeneous compilation, it is a build target. + + target_host : str or :any:`tvm.target.Target`, optional + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + + Returns + ------- + graph_json : str + The json string that can be accepted by graph runtime. + + mod : tvm.Module + The module containing necessary libraries. + + params : dict + The parameters of the final graph. + """ + target = _update_target(target) + + # Setup the build configurations passed in through `with build_config`. + self._setup_build_config(params) + # Build the function + self._build(func, target, target_host) + # Get artifacts + graph_json = self.get_json() + mod = self.get_module() + params = self.get_params() + + return graph_json, mod, params + + def _setup_build_config(self, params): + cfg = BuildConfig.current + + # Set opt_level. + self.set_opt_level(cfg.opt_level) + + # Set fallback device if it is available. + if cfg.fallback_device: + self.set_fallback_device(cfg.fallback_device) + + # Add required passes. + if cfg.add_pass: + passes = set() + if isinstance(cfg.add_pass, (list, tuple, set)): + passes = set(cfg.add_pass) + else: + raise TypeError("add_pass must be list, tuple, or set, but " + + "got {}".format(type(cfg.add_pass))) + for pass_name in passes: + self.add_pass(pass_name) + + # Add disabled passes. + if cfg.disable_pass: + passes = set() + if isinstance(cfg.disable_pass, (list, tuple, set)): + passes = set(cfg.disable_pass) + else: + raise TypeError("disable_pass must be list, tuple, or set, " + + "but got {}".format(type(cfg.disable_pass))) + for pass_name in passes: + self.disable_pass(pass_name) + + if params: + self._set_params(params) + + def _set_params(self, params): + inputs = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = _nd.array(param) + inputs[name] = _expr.const(param) + self._set_params_func(inputs) + + def add_pass(self, pass_name): + """Add a pass to the pass list. + + Parameters + ---------- + pass_name : str + The name of the pass that will be added to the list of passes used + for optimizations. + """ + self._add_pass(pass_name) + + def disable_pass(self, pass_name): + """Add a pass to the disabled pass list. + + Parameters + ---------- + pass_name : str + The name of a pass. This pass will be added to the list of passes + that are disabled during optimization. + """ + self._disable_pass(pass_name) + + def get_json(self): + """Return the json file of the built program.""" + return self._get_graph_json() + + def get_module(self): + """Return the built module.""" + return self._get_module() + + def get_params(self): + """Return the updated weights.""" + params = self._get_params_func() + ret = {} + for key, value in params.items(): + ret[key] = value.data + return ret + + def set_opt_level(self, level): + """Set the optimization level. + + Parameters + ---------- + level : int + The optimization level for build. + """ + self._set_opt_level(level) + + def set_fallback_device(self, fallback_device): + """Set the fallback device for heterogeneous execution. + + Parameters + ---------- + fallback_device : str or tvm.TVMContext + The fallback device used for heterogeneous execution. + """ + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device) + if not isinstance(fallback_device, TVMContext): + raise TypeError("fallback_device is expected to be str " + + "TVMContext, or dict of device name to target, " + + "but received: {}".format(type(fallback_device))) + + self._set_fallback_device(fallback_device.device_type) def build(func, target=None, target_host=None, params=None): - """Build a function to run on TVM graph runtime. + """Helper function that builds a Relay function to run on TVM graph + runtime. Parameters ---------- @@ -266,146 +335,28 @@ def build(func, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - target = target if target else _target.current_target() - if target is None: - raise ValueError("Target is not set in env or passed as argument.") + target = _update_target(target) - if isinstance(target, dict): - target, fallback_device = _update_heterogeneous_inputs(target) - elif isinstance(target, (str, _target.Target)): - target = _target.create(target) - else: - raise ValueError("target must be the type of str, tvm.target.Target," + - "or dict of device name to target") + if isinstance(target_host, (str, _target.Target)): + target_host = _target.create(target_host) + elif target_host: + raise ValueError("target host must be the type of str, " + + "tvm.target.Target, or None") # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - if isinstance(target, dict): - tophub_context = autotvm.tophub.context(list(target.values())) - else: - tophub_context = autotvm.tophub.context(target) + tophub_context = autotvm.tophub.context(list(target.values())) else: tophub_context = autotvm.util.EmptyContext() - cfg = BuildConfig.current - with tophub_context: - func = optimize(func, target, params) - # Annotate the ops for heterogeneous execution. - if isinstance(target, dict): - func, target = _run_device_annotation_passes(func, target, - fallback_device) - # Fuse ops before running code gen - func = ir_pass.infer_type(func) - func = ir_pass.fuse_ops(func, cfg.opt_level) - # Graph code generation - func = ir_pass.infer_type(func) - graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) - graph_json, lowered_funcs, params = graph_gen.codegen(func) - mod = _tvm_build_module( - lowered_funcs, target=target, target_host=target_host) + bld_mod = BuildModule() + graph_json, mod, params = bld_mod.build(func, target, target_host, + params) return graph_json, mod, params -def _update_heterogeneous_inputs(target): - """Update the target and fallback device required for heterogeneous - compilation. CPU is used as the fallback device if it wasn't provided. - Meanwhile, a CPU device type and "llvm" pair will be added to the target - dictionary in this case. - - Parameters - ---------- - target : dict of str(i.e. device/context name) to str/tvm.target.Target. - A dict contains context to target pairs. - - Returns - ------- - device_target : dict of int to tvm.target.Target. - The updated device type to target dict. - - fallback_device : int - The updated fallback device type. - """ - if not isinstance(target, dict): - raise ValueError("target must be dict of device name to target for " + - "heterogeneous execution, but received %s." - % type(target)) - - fallback_device = BuildConfig.current.fallback_device - if fallback_device is None: - # cpu is used as the default fallback device when heterogeneous - # execution is needed, but no fallback device is provided. - fallback_device = _nd.cpu(0).device_type - target[fallback_device] = str(_target.create("llvm")) - elif isinstance(fallback_device, str): - fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, TVMContext): - fallback_device = fallback_device.device_type - else: - raise ValueError("fallback_device expects the type of str or " + - "TVMContext, but received %s." % type(fallback_device)) - - device_target = {} - for dev, tgt in target.items(): - device_target[_nd.context(dev).device_type] = _target.create(tgt) - - if fallback_device not in device_target: - raise ValueError("%s is used as the default device, but the target" + - "is not provided." - % _nd.context(fallback_device).device_name) - return device_target, fallback_device - - -def _run_device_annotation_passes(func, target, fallback_device): - """Execute the device annotation passes to update the input program and - target information. - - Parameters - ---------- - func: tvm.relay.Function - The function where annotation passes will be execute at. - - target : Dict[int, tvm.target.Target] - A dict contains device type to target pairs. - - fallback_device : int - The fallback device type. - - Returns - ------- - target : Dict[int, tvm.target.Target] - The updated device type to target dict. - - func : tvm.relay.Function - The updated func. - """ - func = ir_pass.infer_type(func) - func = ir_pass.rewrite_annotated_ops(func, fallback_device) - device_map = ir_pass.collect_device_info(func) - # The expression to device type map will be empty if all or none of - # the expressions in the `func` are annotated because this map is - # obtained by propagating the device information in the device copy - # operator. None of the above cases needs device copy operator. - if not device_map: - annotation_map = ir_pass.collect_device_annotation_ops(func) - # No annotation. - if not annotation_map: - target = {0: target[fallback_device]} - else: - dev_type = next(iter(annotation_map.values())) - # All annotated with the same device type. - if all(val == dev_type for val in annotation_map.values()): - target = {0: target[dev_type]} - else: - raise RuntimeError("Expressions in the function are " - "annotated with various device types," - "but not device copy operators " - "found. Please check the " - "RewriteAnnotation pass.") - return func, target - - class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 607ee1821c86..b84d3eb40037 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -269,6 +269,77 @@ def realize(graph): return _quantize.realize(graph) +def optimize(func, params=None): + """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and + "CanonicalizeOps" optimization before quantization. + + # TODO(zhiics) These passes are executed one by one so far. We need to + # move them to the pass manager. + + Parameters + --------- + func: tvm.relay.Function + The original Relay function to be optimized. + + params : dict of str to tvm.NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + + Returns + ------- + ret: tvm.relay.Function + The graph after quantization + """ + + opt_passes = ["SimplifyInference", + "FoldScaleAxis", + "FoldConstant", + "CanonicalizeOps"] + + cfg = _build.build_config(add_pass=opt_passes) + + if params: + name_dict = {} + for arg in func.params: + name = arg.name_hint + if name in name_dict: + name_dict[name] = None + else: + name_dict[name] = arg + bind_dict = {} + for k, v in params.items(): + if k not in name_dict: + continue + arg = name_dict[k] + if arg is None: + raise ValueError("Multiple args in the function have name %s" % k) + bind_dict[arg] = _expr.const(v) + func = _expr.bind(func, bind_dict) + + if "SimplifyInference" in cfg.add_pass: + func = _ir_pass.infer_type(func) + func = _ir_pass.simplify_inference(func) + + if "FoldConstant" in cfg.add_pass: + func = _ir_pass.fold_constant(func) + + if "FoldScaleAxis" in cfg.add_pass: + func = _ir_pass.infer_type(func) + func = _ir_pass.backward_fold_scale_axis(func) + func = _ir_pass.infer_type(func) + func = _ir_pass.forward_fold_scale_axis(func) + func = _ir_pass.fold_constant(func) + + if "CanonicalizeOps" in cfg.add_pass: + func = _ir_pass.infer_type(func) + func = _ir_pass.canonicalize_ops(func) + + if "FoldConstant" in cfg.add_pass: + func = _ir_pass.fold_constant(func) + + return func + + def quantize(graph, params=None, dataset=None): """ The quantization procedure. Before running the three main procedure of quantization, "annotate", "calibrate" and "realize" @@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - opt_passes = ["SimplifyInference", - "FoldScaleAxis", - "FoldConstant", - "CanonicalizeOps"] - with _build.build_config(add_pass=opt_passes): - graph = _build.optimize(graph, params=params) + # TODO(zhiics) Move this to the pass manager. + graph = optimize(graph, params) graph = annotate(graph) graph = calibrate(graph, dataset) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 57e300fafec2..9b30ced90c4f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -311,7 +311,7 @@ bool LLVMEnabled() { /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { - if (target->device_type == kDLCPU) { + if (target.defined() && target->device_type == kDLCPU) { return target; } else { if (LLVMEnabled()) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 08a88d53350f..63ee2d59d854 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -38,54 +38,31 @@ namespace tvm { namespace relay { namespace backend { +using TargetsMap = Map; + /*! - * \brief Context name / index - * See: python/tvm/_ffi/runtime_ctypes.py + * \brief Context index to Target */ -struct ContextMap { - static const std::unordered_map mask2str; - static const std::unordered_map str2mask; - static std::string Mask2Str(int mask) { +struct ContextTargetMap { + static const std::unordered_map mask2str; + static tvm::Target Mask2Str(int mask) { CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; return mask2str.at(mask); } - static int Str2Mask(const std::string& str) { - CHECK_GT(str2mask.count(str), 0) << "Unknown context."; - return str2mask.at(str); - } -}; - -const std::unordered_map ContextMap::mask2str = { - {1, "cpu"}, - {2, "gpu"}, - {4, "opencl"}, - {5, "aocl"}, - {6, "sdaccel"}, - {7, "vulkan"}, - {8, "metal"}, - {9, "vpi"}, - {10, "rocm"}, - {11, "opengl"}, - {12, "ext_dev"} }; -const std::unordered_map ContextMap::str2mask = { - {"llvm", 1}, - {"cpu", 1}, - {"c", 1}, - {"gpu", 2}, - {"cuda", 2}, - {"nvptx", 2}, - {"cl", 4}, - {"opencl", 4}, - {"aocl", 5}, - {"aocl_sw_emu", 5}, - {"vulkan", 7}, - {"metal", 8}, - {"vpi", 9}, - {"rocm", 10}, - {"opengl", 11}, - {"ext_dev", 12} +const std::unordered_map ContextTargetMap::mask2str = { + {1, tvm::Target::create("llvm")}, + {2, tvm::Target::create("cuda")}, + {4, tvm::Target::create("opencl")}, + {5, tvm::Target::create("aocl")}, + {6, tvm::Target::create("sdaccel")}, + {7, tvm::Target::create("vulkan")}, + {8, tvm::Target::create("metal")}, + {9, tvm::Target::create("vpi")}, + {10, tvm::Target::create("rocm")}, + {11, tvm::Target::create("opengl")}, + {12, tvm::Target::create("ext_dev")} }; /*! @@ -137,7 +114,7 @@ struct BuildOutput { */ struct RelayBuildConfig { int opt_level{2}; - std::string fallback_device{"llvm"}; + int fallback_device{static_cast(kDLCPU)}; std::unordered_set enabled_pass; std::unordered_set disabled_pass; OptPassLevel OPT_PASS_LEVEL; @@ -164,14 +141,8 @@ struct GraphCodegen { } ~GraphCodegen() {} - void Init(runtime::Module* m, - Map targets) { - Array tgts; - for (auto kv : targets) { - tgts.push_back(kv.first); - tgts.push_back(kv.second); - } - CallFunc("init", m, tgts); + void Init(runtime::Module* m, TargetsMap targets) { + CallFunc("init", m, targets); } void Codegen(const Function& func) { @@ -248,14 +219,7 @@ class RelayBuildModule : public runtime::ModuleNode { } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); - Array tmp = args[1]; - std::unordered_map targets; - for (size_t i = 0; i < tmp.size(); i += 2) { - auto k = tmp[i].as()->value; - auto v = tmp[i + 1].as()->value; - targets[k] = v; - } - this->Build(args[0], targets, args[2]); + this->Build(args[0], args[1], args[2]); }); } else if (name == "list_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -273,7 +237,8 @@ class RelayBuildModule : public runtime::ModuleNode { }); } else if (name == "set_fallback_device") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string dev = args[0]; + CHECK_EQ(args.num_args, 1); + int dev = args[0]; this->SetFallBackDev(dev); }); } else if (name == "add_pass") { @@ -328,7 +293,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \param device name */ - void SetFallBackDev(const std::string& dev) { + void SetFallBackDev(int dev) { cfg_.fallback_device = dev; } /*! @@ -402,8 +367,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target_host Host target device */ void Build(Function func, - const std::unordered_map& targets, - const std::string& target_host) { + const TargetsMap& targets, + const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; BuildRelay(func, cfg_, params_); @@ -416,8 +381,9 @@ class RelayBuildModule : public runtime::ModuleNode { * \param params params dict * \return relay::Function */ - relay::Function BindParamsByName(relay::Function func, - const std::unordered_map& params) { + relay::Function BindParamsByName( + relay::Function func, + const std::unordered_map& params) { std::unordered_map name_dict; std::unordered_set repeat_var; for (auto arg : func->params) { @@ -454,7 +420,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \return relay::Function */ relay::Function Optimize(relay::Function func, - const std::unordered_map& targets, + const TargetsMap& targets, const RelayBuildConfig& cfg, const std::unordered_map& params) { if (params.size()) { @@ -507,8 +473,7 @@ class RelayBuildModule : public runtime::ModuleNode { auto enter_pf = GetPackedFunc("_EnterTargetScope"); auto exit_pf = GetPackedFunc("_ExitTargetScope"); for (const auto& kv : targets) { - auto target = Target::create(kv.second); - (*enter_pf)(target); + (*enter_pf)(kv.second); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); (*exit_pf)(); } @@ -530,25 +495,19 @@ class RelayBuildModule : public runtime::ModuleNode { * * \param targets dictionary * \param cfg - * \return Map + * \return Map */ - Map UpdateHeterogeneousInputs( - const std::unordered_map& targets, - const RelayBuildConfig& cfg) { - Map device_target; - std::unordered_map tmp_map; - auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); - + TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, + const RelayBuildConfig& cfg) { + TargetsMap device_target = targets; + std::unordered_map tmp_map; for (const auto& kv : targets) { - tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; - } - if (tmp_map.count(fallback_idx) == 0) { - tmp_map[fallback_idx] = cfg.fallback_device; + tmp_map[kv.first->value] = kv.second; } - for (const auto& kv : tmp_map) { + if (tmp_map.count(cfg.fallback_device) == 0) { device_target.Set( - ir::IntImm::make(HalideIR::Int(64), kv.first), - ir::StringImm::make(kv.second)); + cfg.fallback_device, + ContextTargetMap::Mask2Str(cfg.fallback_device)); } return device_target; } @@ -561,25 +520,19 @@ class RelayBuildModule : public runtime::ModuleNode { * \param targets_map_ptr * \return Function */ - Function RunDeviceAnnotationPass( - Function func, - const RelayBuildConfig& cfg, - Map* targets_map_ptr) { - auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); + Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg, + TargetsMap* targets_map_ptr) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); - auto device_map = CallPackedFunc >("relay._ir_pass.CollectDeviceInfo", - func, - nullptr); + func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, + cfg.fallback_device); + auto device_map = CallPackedFunc >( + "relay._ir_pass.CollectDeviceInfo", func, nullptr); if (device_map.size() == 0) { - auto annotation_map = - CallPackedFunc >("relay._ir_pass.CollectDeviceAnnotationOps", - func, - nullptr); + auto annotation_map = CallPackedFunc >( + "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); if (annotation_map.size() == 0) { targets_map_ptr->Set( - ir::IntImm::make(HalideIR::Int(64), 0), - ir::StringImm::make(cfg.fallback_device)); + 0, ContextTargetMap::Mask2Str(cfg.fallback_device)); } else { int64_t dev_type = -1; for (auto kv : annotation_map) { @@ -594,9 +547,7 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set( - ir::IntImm::make(HalideIR::Int(64), 0), - ir::StringImm::make(ContextMap::Mask2Str(dev_type))); + targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type)); } } return func; @@ -614,15 +565,11 @@ class RelayBuildModule : public runtime::ModuleNode { const std::unordered_map ¶ms) { // convert tvm_cfg_ = build_config(); - Map device_target; + TargetsMap device_target; if (targets_.size() > 1) { device_target = UpdateHeterogeneousInputs(targets_, cfg); } else { - for (auto &kv : targets_) { - device_target.Set( - ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)), - ir::StringImm::make(kv.second)); - } + device_target = targets_; } func = Optimize(func, targets_, cfg, params); if (device_target.size() > 1) { @@ -640,16 +587,15 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - auto target_host = Target::create(target_host_); - ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_); + ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_); } protected: std::unique_ptr graph_codegen_; /*! \brief target device */ - std::unordered_map targets_; + TargetsMap targets_; /*! \brief target host device */ - std::string target_host_; + tvm::Target target_host_; /*! \brief frontend optimization configure */ RelayBuildConfig cfg_; /*! \brief parameters */ diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 415e0ec9c2a5..b14448c59166 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map; using GraphNodePtr = std::shared_ptr; using GraphInputNodePtr = std::shared_ptr; using GraphOpNodePtr = std::shared_ptr; -using TargetsMap = std::unordered_map; +using TargetsMap = std::unordered_map; /*! \brief Lowered outputs */ struct LoweredOutput { @@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode { class GraphRuntimeCodegen : public ::tvm::relay::ExprFunctor(const Expr&)> { public: - GraphRuntimeCodegen(runtime::Module* mod, - const std::unordered_map& targets) : mod_(mod) { + GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) + : mod_(mod) { compile_engine_ = CompileEngine::Global(); - for (auto &kv : targets) { - targets_[kv.first] = Target::create(kv.second); - } + targets_ = targets; } LoweredOutput Codegen(relay::Function func) { @@ -406,7 +404,7 @@ class GraphRuntimeCodegen auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); auto &device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; //-> int to string + auto call_dev_type = device_type[0]->value; Target target; if (targets_.size() == 1) { // homogeneous execution. @@ -415,22 +413,17 @@ class GraphRuntimeCodegen } } else { // heterogeneous execution. - const auto call_dev_key = std::to_string(call_dev_type); std::string call_dev_name; if (call_dev_type == 0) { call_dev_name = "llvm"; } else { call_dev_name = runtime::DeviceName(call_dev_type); } - if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { + if (targets_.count(call_dev_type) == 0) { LOG(FATAL) << "No target is provided for device " << call_dev_name; } - if (targets_.count(call_dev_key)) { - target = targets_[call_dev_key]; - } else { - target = targets_[call_dev_name]; - } + target = targets_[call_dev_type]; } CCacheKey key = (*pf0)(func, target); CachedFunc lowerd_func = (*pf1)(compile_engine_, key); @@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { virtual PackedFunc GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; - void* mod = args[0]; - auto& sptr = args[1].node_sptr(); - auto* node = static_cast(sptr.get()); - auto& tmp_targets = node->data; - std::unordered_map targets; - for (size_t i = 0; i < tmp_targets.size(); i += 2) { - std::string key; - auto sk = Expr(tmp_targets[i]).as(); - auto ik = Expr(tmp_targets[i]).as(); - if (sk) { - key = sk->value; - } - if (ik) { - key = std::to_string(ik->value); - } - auto v = Expr(tmp_targets[i + 1]).as(); - targets[key] = v->value; - } - codegen_ = std::make_shared( - reinterpret_cast(mod), targets); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2) + << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + CHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = std::make_shared( + reinterpret_cast(mod), targets); + }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 38481bfb8204..a1ab29959127 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -73,10 +74,10 @@ TEST(Relay, BuildModule) { auto build_f = build_mod.GetFunction("build", false); auto json_f = build_mod.GetFunction("get_graph_json", false); auto mod_f = build_mod.GetFunction("get_module", false); - Array target_pair; - target_pair.push_back(ir::StringImm::make("cpu")); - target_pair.push_back(ir::StringImm::make("llvm")); - build_f(func, target_pair, "llvm"); + Map targets; + Target llvm_tgt = Target::create("llvm"); + targets.Set(0, llvm_tgt); + build_f(func, targets, llvm_tgt); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py index a03868550160..d3538bb0085b 100644 --- a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py +++ b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py @@ -74,13 +74,12 @@ def convnet(): for tgt in targets: with tvm.target.create(tgt) as target: - with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'): - with autotvm.tophub.context(target): - O = relay.optimize(N, target, params=None) - O = relay.ir_pass.infer_type(O) + with autotvm.tophub.context(target): + O = relay.ir_pass.alter_op_layout(N) + O = relay.ir_pass.infer_type(O) - # graph should differ - assert not relay.ir_pass.alpha_equal(N, O) + # graph should differ + assert not relay.ir_pass.alpha_equal(N, O) if __name__ == "__main__": np.random.seed(42) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index b94f57d77286..affc6ce04c6b 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -18,55 +18,10 @@ import tvm from tvm import relay +from tvm.contrib.nvcc import have_fp16 -from tvm._ffi.function import _init_api -_init_api("tvm.relay.build_module") - -class BuildModule(object): - def __init__(self): - self.mod = relay.build_module._BuildModule() - self._get_graph_json = self.mod["get_graph_json"] - self._get_module = self.mod["get_module"] - self._build = self.mod["build"] - self._set_opt_level = self.mod["set_opt_level"] - self._set_params_func = self.mod["set_params"] - self._get_params_func = self.mod["get_params"] - - - def build(self, func, target, target_host, params): - tgts = [] - for kv in target.items(): - tgts.append(kv[0]) - tgts.append(kv[1]) - self._set_params(params) - self._build(func, tgts, target_host) - - def get_json(self): - return self._get_graph_json() - - def get_module(self): - return self._get_module() - - def set_opt_level(self, level): - self._set_opt_level(level) - - def _set_params(self, params): - inputs = {} - for name, param in params.items(): - inputs[name] = relay.Constant(param) - self._set_params_func(inputs) - - def get_params(self): - params = self._get_params_func() - ret = {} - for key, value in params.items(): - ret[key] = value.data - return ret - - -def test_build(): - m_bld = BuildModule() - tgt_name = "llvm" + +def test_basic_build(): tgt = "llvm" ctx = tvm.cpu() # func @@ -86,21 +41,96 @@ def test_build(): } # build targets = { - tgt: tgt + tvm.expr.IntImm("int32", ctx.device_type): tgt } - m_bld.set_opt_level(3) - m_bld.build(func, targets, "llvm", params=params) - g_json = m_bld.get_json() - mmod = m_bld.get_module() - params = m_bld.get_params() - + g_json, mmod, params = relay.build(func, targets, "llvm", params=params) + # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) rt.set_input("a", A) rt.load_params(relay.save_param_dict(params)) rt.run() out = rt.get_output(0) - - np.testing.assert_allclose(out.asnumpy(), - np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5) - + + np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(), + B.asnumpy().T), + 0) + C.asnumpy(), + atol=1e-5, rtol=1e-5) + + +def test_fp16_build(): + dtype = "float16" + + if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist: + print("skip because cuda is not enabled.") + return + + ctx = tvm.gpu(0) + if dtype == "float16" and not have_fp16(ctx.compute_version): + print("skip because gpu does not support fp16") + return + + x = relay.var("x", dtype=dtype, shape=(4, 4)) + y = relay.var("y", dtype=dtype, shape=(4, 4)) + z = x + y + func = relay.Function([x, y], z) + X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx) + Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx) + params = { + "x": X, + "y": Y, + } + + # build + g_json, mmod, params = relay.build(func, "cuda", params=params) + + # test + rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) + rt.load_params(relay.save_param_dict(params)) + rt.run() + out = rt.get_output(0) + + np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(), + atol=1e-5, rtol=1e-5) + + +def test_fp16_conversion(): + def check_conversion(tgt, ctx): + if not tvm.module.enabled(tgt): + print("skip because {} is not enabled.".format(tgt)) + return + elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version): + print("skip because gpu does not support fp16") + return + + n = 10 + + for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]: + x = relay.var("x", relay.TensorType((n,), src)) + y = x.astype(dst) + func = relay.Function([x], y) + + # init input + X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2) + + # build + with relay.build_config(opt_level=1): + g_json, mmod, params = relay.build(func, tgt) + + # test + rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) + rt.set_input("x", X) + rt.run() + out = rt.get_output(0) + + np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst), + atol=1e-5, rtol=1e-5) + + for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]: + check_conversion(target, ctx) + + +if __name__ == "__main__": + test_basic_build() + test_fp16_build() + test_fp16_conversion() diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 04081e06735b..9a77d2ffe856 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -411,7 +411,7 @@ def expected(): expected_index) def test_fallback_all_operators(device, tgt): - target = {device: tgt} + target = {device: tgt, "cpu": "llvm"} annotated_func = get_func() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index a9a683c43263..1630efce7f6c 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -47,54 +47,54 @@ def test_simulated_quantize(): assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32") -# def test_quantize_pass(): -# def quantize_weight(arr): -# maximum = np.amax(np.abs(arr.asnumpy())) -# scale = 2**math.ceil(math.log(maximum, 2)) -# out = np.around(arr.asnumpy() / scale * 128).astype('int8') -# out = np.clip(out, -127, 127) -# return relay.const(out, 'int8') -# -# n, c, h, w = 1, 3, 224, 224 -# def make_graph(data): -# weight = relay.var("conv_weight") -# out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) -# out = relay.Function(relay.ir_pass.free_vars(out), out) -# return out -# -# def make_qgraph(data, weight): -# out = data * relay.const(32.0) -# out = relay.round(out) -# out = relay.clip(out, a_min=-127, a_max=127) -# out = out.astype('int8') -# -# out = relay.nn.conv2d(out, weight, kernel_size=(3, 3), -# padding=(1, 1), channels=c, out_dtype='int32') -# out = out.astype('float32') -# out = relay.multiply(out, relay.const(0.00024414062)) -# out = relay.Function(relay.ir_pass.free_vars(out), out) -# return out -# -# data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) -# graph = make_graph(data) -# dataset, params = make_dataset(graph, 10) -# -# with qtz.qconfig(skip_k_conv=0, global_scale=4.0, -# round_for_shift=False, store_lowbit_output=False): -# qgraph0 = qtz.quantize(graph, params) -# qgraph0 = relay.ir_pass.infer_type(qgraph0) -# -# conv_weight = quantize_weight(params['conv_weight']) -# qgraph1 = make_qgraph(data, conv_weight) -# qgraph1 = relay.ir_pass.infer_type(qgraph1) -# -# graph = relay.create_executor('graph') -# res0 = graph.evaluate(qgraph0)(dataset[0]['data']) -# res1 = graph.evaluate(qgraph1)(dataset[0]['data']) -# tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3) +def test_quantize_pass(): + def quantize_weight(arr): + maximum = np.amax(np.abs(arr.asnumpy())) + scale = 2**math.ceil(math.log(maximum, 2)) + out = np.around(arr.asnumpy() / scale * 128).astype('int8') + out = np.clip(out, -127, 127) + return relay.const(out, 'int8') + + n, c, h, w = 1, 3, 224, 224 + def make_graph(data): + weight = relay.var("conv_weight") + out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) + out = relay.Function(relay.ir_pass.free_vars(out), out) + return out + + def make_qgraph(data, weight): + out = data * relay.const(32.0) + out = relay.round(out) + out = relay.clip(out, a_min=-127, a_max=127) + out = out.astype('int8') + + out = relay.nn.conv2d(out, weight, kernel_size=(3, 3), + padding=(1, 1), channels=c, out_dtype='int32') + out = out.astype('float32') + out = relay.multiply(out, relay.const(0.00024414062)) + out = relay.Function(relay.ir_pass.free_vars(out), out) + return out + + data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) + graph = make_graph(data) + dataset, params = make_dataset(graph, 10) + + with qtz.qconfig(skip_k_conv=0, global_scale=4.0, + round_for_shift=False, store_lowbit_output=False): + qgraph0 = qtz.quantize(graph, params) + qgraph0 = relay.ir_pass.infer_type(qgraph0) + + conv_weight = quantize_weight(params['conv_weight']) + qgraph1 = make_qgraph(data, conv_weight) + qgraph1 = relay.ir_pass.infer_type(qgraph1) + + graph = relay.create_executor('graph') + res0 = graph.evaluate(qgraph0)(dataset[0]['data']) + res1 = graph.evaluate(qgraph1)(dataset[0]['data']) + tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3) if __name__ == "__main__": np.random.seed(42) test_simulated_quantize() - # test_quantize_pass() + test_quantize_pass() From e5707676615688a3a36ff532f924493153431526 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 15 May 2019 19:03:40 -0700 Subject: [PATCH 008/176] [Relay] Option to select which convolution layers are quantized. (#3173) * Stashing for later maybe. * Added new option to leave specific layers unquantized. * Better error checking. * remove unneeded import * tab to spaces * pylint fixes * more pylint fixes --- python/tvm/relay/quantize/_annotate.py | 49 +++++++++++++++++++++++--- python/tvm/relay/quantize/quantize.py | 5 +++ src/relay/pass/quantize.cc | 1 + src/relay/pass/quantize.h | 2 ++ topi/python/topi/cuda/conv2d.py | 4 ++- 5 files changed, 55 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index e52ce142e5c3..9bf546fcdadf 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx): if cnt < current_qconfig().skip_k_conv: _set_conv_counter(cnt + 1) return None + + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices: + _set_conv_counter(cnt + 1) + return None + _set_conv_counter(cnt + 1) lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx): cnt = _conv_counter() if cnt < current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -194,8 +207,13 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx): def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None x_expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -262,8 +290,14 @@ def identity_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None + expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -280,8 +314,13 @@ def pool2d_rewrite(ref_call, new_args, ctx): @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt - 1 in leave_alone_indices: + return None input_tuple = new_args[0] expr_list = [_get_expr_kind(x)[0] for x in input_tuple] diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index b84d3eb40037..7fd0099e64a2 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -71,6 +71,7 @@ class QConfig(NodeBase): "dtype_activation": "int32", "global_scale": 8.0, "skip_k_conv": 1, + "skip_conv_layers": None, "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, @@ -139,6 +140,10 @@ def qconfig(**kwargs): skip_k_conv: int The number of skipped conv2d. + skip_conv_layers: list + Different way of specifying which layers to avoid. Provide a list of indices + that indicate which conv2d layers to leave untouched. + round_for_shift: boolean Whether to add bias for rounding during shift. diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 7fd27b46ad6a..3a2e54c8ad39 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_k_conv==" << op->skip_k_conv << ", "; + p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 4d26dd6be4a5..2c70da177199 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -126,6 +126,7 @@ class QConfigNode : public Node { DataType dtype_activation = Int(32); double global_scale = 8.0; int skip_k_conv = 1; + Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; bool store_lowbit_output = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); @@ -140,6 +141,7 @@ class QConfigNode : public Node { v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); v->Visit("skip_k_conv", &skip_k_conv); + v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 256f91567b71..220235e56a33 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -104,7 +104,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed=False) if cfg.template_key == 'int8': - return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + if (data.dtype == 'int8' or data.dtype == 'uint8'): + return conv2d_NCHWc_int8( + cfg, data, kernel, strides, padding, dilation, layout, out_dtype) if layout == 'NCHW': return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) From 73af04f06b7bf1606893e0a637b2dd39b10be425 Mon Sep 17 00:00:00 2001 From: llyfacebook <34827865+llyfacebook@users.noreply.github.com> Date: Wed, 15 May 2019 20:21:35 -0700 Subject: [PATCH 009/176] Add the acc16 intrinsic support (#3081) --- tests/python/contrib/test_gemm_acc16.py | 90 +++++++++++++++++++++++++ topi/python/topi/x86/tensor_intrin.py | 79 ++++++++++++++++++++++ 2 files changed, 169 insertions(+) create mode 100644 tests/python/contrib/test_gemm_acc16.py diff --git a/tests/python/contrib/test_gemm_acc16.py b/tests/python/contrib/test_gemm_acc16.py new file mode 100644 index 000000000000..0fc5e1a9a3fa --- /dev/null +++ b/tests/python/contrib/test_gemm_acc16.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition +import tvm +import numpy as np +from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int16 + + +def benchmark_fc_int8_acc16(): + m = 128 + n = 128 + k = 128 + + X = tvm.placeholder((m, k), name='X', dtype="uint8") + W = tvm.placeholder((n, k), name='W', dtype="int8") + + peak = 512/16*2*2*2 + gops_per_mm = 2*n*m*k + print("Peak {} Gops/s \n".format(peak)) + + def verify(target="llvm -mcpu=skylake-avx512"): + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + + ctx = tvm.context(target, 0) + X = tvm.placeholder((m, k), name='X', dtype="uint8") + W = tvm.placeholder((n, k), name='W', dtype="int8") + pc = dot_16x1x16_int8_int8_int16() + ak = tvm.reduce_axis((0, k), name='k') + + packedW = tvm.placeholder((n/128, 128*(k/2), 2), name='packedW', dtype="int8") + t_fc = tvm.compute((m, n), lambda i, j: tvm.sum(X[i, ak].astype("int16") * packedW[j/128, (ak/2)*128+j%128, ak%2].astype("int16"), axis=ak), name="F") + + t_sch = tvm.create_schedule(t_fc.op) + a_x, a_y = t_fc.op.axis + a_k, = t_fc.op.reduce_axis + + a_yo, a_yi = t_sch[t_fc].split(a_y, factor=128) + a_ko, a_ki = t_sch[t_fc].split(a_k, factor=2) + + a_xo, a_xi = t_sch[t_fc].split(a_x, factor=128) + a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=32) + t_sch[t_fc].reorder(a_yo, a_xo, a_koo, a_xi, a_koi, a_yi, a_ki) + + t_sch[t_fc].tensorize(a_yi, pc) + # print(tvm.lower(t_sch, [X, packedW, t_fc], simple_mode=True)) + t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic") + t_evaluator = t_func.time_evaluator(t_func.entry_name, ctx, number=10) + + # generate the plain data + a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8") + b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8") + + packW = np.random.uniform(1, 10, size=(n/128, 128*(k/2), 2)).astype("int8") + # This occurs in pre_compute stage + for r_idx in range(n/128): + for s_idx in range(128*(k/2)): + for t_idx in range(2): + packW[r_idx][s_idx][t_idx] = b_[r_idx*128+s_idx%128][s_idx/128*2+t_idx] + + x = tvm.nd.array(a_, ctx) + w = tvm.nd.array(packW, ctx) + y = tvm.nd.array(np.zeros((m, n), dtype="int16"), ctx) + + result = t_evaluator(x, w, y) + gops_per_sec = gops_per_mm/result.mean/1e9 + tvm.testing.assert_allclose( + y.asnumpy(), np.dot(a_, b_.T), rtol=1e-5) + print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}.'.format(result.mean*1000, gops_per_sec, gops_per_sec/peak)) + t_func.export_library("gemm_tensorize.o") + + verify() + +if __name__ == "__main__": + benchmark_fc_int8_acc16() diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index 48fa75d81c9b..00681726257a 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -98,3 +98,82 @@ def _instr(index): with tvm.build_config(offset_factor=1, partition_const_loop=True): return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) + + +def dot_16x1x16_int8_int8_int16(): + """ + Int8 dot product by every 2 elements using AVX2 Skylake instructions. + This function takes two arrays of int8 datatype -- data[2] and + kernel[4][32][2] -- and computes a dot product of data[2] with every + 2 elements of kernels, resulting in output[4][32] of int16 datatype. + The pseudo code is as follows. + .. code-block:: c + void dot_16x1x16_int8_int8_int16(int8 data[2], int8 kernel[32*4][2], + int16 output[32*4]){ + for (int i = 0; i< 4; i++){ + for (int j = 0; j < 32; j++){ + out[i][i] = 0; + for (int k = 0; k < 2; k++){ + out[i][j][k] += data[k] * kernel[i][j][k] + } + } + } + } + Physically, the kernel array sits in four AVX512 vector registers and + the data[2] is broadcasted to another AVX512 vector register. This + function returns a TensorIntrin that can be used to tensorize + a schedule. + Returns + ------- + intrin : TensorIntrin + The Skylake int8 TensorIntrin that can be used in tensorizing schedule + """ + + num_int8_elements = 2 # 2 int8 elements in int32 + data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data') + kernel = tvm.placeholder((128, num_int8_elements), dtype='int8', name='kernel') + k = tvm.reduce_axis((0, num_int8_elements), name='k') + C = tvm.compute((128, ), + lambda i: tvm.sum(data[k].astype('int16') * + kernel[i, k].astype('int16'), + axis=k), + name="C") + + a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer", + offset_factor=1, + strides=[1]) + b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer", + offset_factor=1) + # strides=[tvm.var('ldw'), 1, 1]) + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.ir_builder.create() + if index == 1: + for i in range(4): + ib.emit(outs[0].vstore([i*32], tvm.const(0, 'int16x32'))) + return ib.get() + + a_int8 = ins[0].vload([0], "uint8x2") + re_int16 = tvm.call_pure_intrin('int16', 'reinterpret', a_int8) + vec_ai16 = re_int16.astype('int16x32') + vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai16) + + for i in range(4): + vec_b = ins[1].vload([i*32, 0], "int8x64") + pair_reduction = tvm.call_llvm_intrin('int16x32', + 'llvm.x86.avx512.pmaddubs.w.512', + tvm.const(0, 'uint32'), + vec_a, vec_b) + if index == 0: + ib.emit(outs[0].vstore([i*32], pair_reduction)) + else: + ib.emit(outs[0].vstore([i*32], pair_reduction + outs[0].vload([i*32], + 'int16x32'))) + return ib.get() + + # body, reset, update + return _instr(0), _instr(1), _instr(2) + + with tvm.build_config(offset_factor=1, partition_const_loop=True): + return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) From 30d026fc64af190bcb87c17be8c4622e449a5779 Mon Sep 17 00:00:00 2001 From: Mark Rogers Date: Wed, 15 May 2019 20:26:24 -0700 Subject: [PATCH 010/176] Get list of unsupported ONNX operators (#2995) --- nnvm/python/nnvm/frontend/onnx.py | 13 +++++++++++++ python/tvm/relay/frontend/onnx.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index eb78b7845c23..2434fb01c1d5 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -830,6 +830,19 @@ def from_onnx(self, graph, opset): else: self._num_input += 1 self._nodes[i_name] = _sym.Variable(name=i_name) + # get list of unsupported ops + convert_map = _get_convert_map(opset) + unsupported_ops = set() + for node in graph.node: + op_name = node.op_type + if op_name not in convert_map and \ + op_name != 'Constant' and \ + op_name not in _identity_list: + unsupported_ops.add(op_name) + if unsupported_ops: + msg = 'The following operators are not supported for frontend ONNX: ' + msg += ', '.join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: op_name = node.op_type diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 08a64c37d8df..470f4197c908 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -989,6 +989,19 @@ def from_onnx(self, graph, opset): else: dtype = d_type self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype) + # get list of unsupported ops + convert_map = _get_convert_map(opset) + unsupported_ops = set() + for node in graph.node: + op_name = node.op_type + if op_name not in convert_map and \ + op_name != 'Constant' and \ + op_name not in _identity_list: + unsupported_ops.add(op_name) + if unsupported_ops: + msg = 'The following operators are not supported for frontend ONNX: ' + msg += ', '.join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: op_name = node.op_type From 6bcc40d26449ae6d8fb078307c8df27985b761b7 Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 16 May 2019 09:25:38 +0530 Subject: [PATCH 011/176] [TENSORLFOW] PlaceholderWithDefault (limited) implementation. (#3184) --- python/tvm/relay/frontend/tensorflow.py | 6 +++--- .../frontend/tensorflow/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4bd78b47fe54..b5a9ea5781aa 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1740,7 +1740,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): for node in graph.node: node_name_prefix = node.name.rsplit('/', 1)[0] control_flow_node_map[node_name_prefix].add(node.op) - if node.op == 'Placeholder': + if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': # Give priority to user argument. if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) @@ -1800,7 +1800,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr = self._parse_attr(node.attr) - elif node.op != "Placeholder": + elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] @@ -1925,7 +1925,7 @@ def _parse_import_prerequisites(self, graph): """ missing_operators = set() for node in graph.node: - if node.op == "Placeholder": + if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 2f1cc2f6c9a4..90ee75823518 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1541,6 +1541,24 @@ def test_forward_reduce_prod(): _test_forward_reduce_prod((5, 5), 0, True) _test_forward_reduce_prod((5, 5), 1, True) + +####################################################################### +# PlaceholderWithDefault +# ---------------------- +def test_placeholder(): + with tf.Graph().as_default(): + in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + var1 = tf.Variable(in_data1, name='in1') + var2 = array_ops.placeholder_with_default(var1, None, name='place1') + + in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2') + + out1 = tf.math.add(var1, var2, name='out1') + out2 = tf.math.add(out1, place1, name='out2') + + compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) + ####################################################################### # Main # ---- @@ -1590,6 +1608,7 @@ def test_forward_reduce_prod(): test_forward_multi_input() test_forward_multi_output() test_forward_variable() + test_placeholder() # NN test_forward_convolution() From 22a7af37740490975ebd4bee8fe1ae1a2dee75e6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 16 May 2019 23:39:13 +0800 Subject: [PATCH 012/176] [TOPI] Raise exception group_conv2d_nchw not supported (#3195) --- topi/python/topi/cuda/group_conv2d_nchw.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index be4ae3554e33..cbdb3cb8031d 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -346,6 +346,8 @@ def schedule_conv2d_nchw_cuda(cfg, outs): def _callback(op): if op.tag == "group_conv2d_NCHWc_int8": schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0)) + if op.tag == "group_conv2d_nchw": + raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported") traverse_inline(s, outs[0].op, _callback) return s From a2e6d10f1dee3bc4093c2e481132b5f31c11bd40 Mon Sep 17 00:00:00 2001 From: Philipp Krones Date: Thu, 16 May 2019 20:00:38 +0200 Subject: [PATCH 013/176] Quick fix of VTA FPGA Toolchain Installation documentation (#3196) --- docs/vta/install.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/vta/install.md b/docs/vta/install.md index 233bb5ca0260..8fa779a5d5b8 100644 --- a/docs/vta/install.md +++ b/docs/vta/install.md @@ -208,7 +208,7 @@ chmod u+x Xilinx_Vivado_SDK_Web_2018.2_0614_1954_Lin64.bin #### Xilinx Vivado GUI Installer Steps -At this point you've launched the Vivado 2017.1 Installer GUI program. +At this point you've launched the Vivado 2018.2 Installer GUI program. 1. Click “Next” on the *Welcome* screen. 2. On the *Select Install Type* screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next” . From 57963d5b25441ccf4390f9848ab335c5b3c6922d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Thu, 16 May 2019 11:54:43 -0700 Subject: [PATCH 014/176] Update .gitignore (#3199) --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index e44aa6e21464..a7355739cf59 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,7 @@ patched.txt # Python type checking .mypy_cache/ .pyre/ + +# pipenv file +Pipfile +Pipfile.lock From fc00608f1a559a9b4d5fb3ebe13f334690ed348f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 16 May 2019 13:22:31 -0700 Subject: [PATCH 015/176] [RELAY] Hotfix build_module creation (#3198) --- src/relay/backend/build_module.cc | 61 ++++++++++++------------------- 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 63ee2d59d854..8a0c32fc6684 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -18,12 +18,11 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ - #include +#include #include #include #include @@ -40,31 +39,6 @@ namespace backend { using TargetsMap = Map; -/*! - * \brief Context index to Target - */ -struct ContextTargetMap { - static const std::unordered_map mask2str; - static tvm::Target Mask2Str(int mask) { - CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; - return mask2str.at(mask); - } -}; - -const std::unordered_map ContextTargetMap::mask2str = { - {1, tvm::Target::create("llvm")}, - {2, tvm::Target::create("cuda")}, - {4, tvm::Target::create("opencl")}, - {5, tvm::Target::create("aocl")}, - {6, tvm::Target::create("sdaccel")}, - {7, tvm::Target::create("vulkan")}, - {8, tvm::Target::create("metal")}, - {9, tvm::Target::create("vpi")}, - {10, tvm::Target::create("rocm")}, - {11, tvm::Target::create("opengl")}, - {12, tvm::Target::create("ext_dev")} -}; - /*! * \brief A data structure to map the names of specific optimizations to * numeric optimization levels @@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return Array names of params */ - Array ListParamNames() { - Array ret; + Array ListParamNames() { + Array ret; for (const auto& kv : params_) { ret.push_back(ir::StringImm::make(kv.first)); } @@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode { if (cfg.pass_enabled("AlterOpLayout")) { if (targets.size() == 1) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - auto enter_pf = GetPackedFunc("_EnterTargetScope"); - auto exit_pf = GetPackedFunc("_ExitTargetScope"); for (const auto& kv : targets) { - (*enter_pf)(kv.second); + TargetContext tctx(kv.second); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - (*exit_pf)(); } } else { LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" @@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode { } return func; } + + /*! + * \brief Create a default type. + * \param device_type The device type index. + * \return the default target for the device. + */ + Target CreateDefaultTarget(int device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") return Target::create("llvm"); + if (name == "gpu") return Target::create("cuda"); + return Target::create(name); + } /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. @@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (tmp_map.count(cfg.fallback_device) == 0) { device_target.Set( cfg.fallback_device, - ContextTargetMap::Mask2Str(cfg.fallback_device)); + CreateDefaultTarget(cfg.fallback_device)); } return device_target; } @@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param targets_map_ptr * \return Function */ - Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg, + Function RunDeviceAnnotationPass(Function func, + const RelayBuildConfig& cfg, TargetsMap* targets_map_ptr) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, @@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode { "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); if (annotation_map.size() == 0) { targets_map_ptr->Set( - 0, ContextTargetMap::Mask2Str(cfg.fallback_device)); + 0, CreateDefaultTarget(cfg.fallback_device)); } else { int64_t dev_type = -1; for (auto kv : annotation_map) { @@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type)); + targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); } } return func; @@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") +.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); }); From d9a379832356d616f6abd8dd03f1c6f94a3e8e5e Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 17 May 2019 03:41:50 -0700 Subject: [PATCH 016/176] [Relay] Better shape inference in TensorFlow Frontend. (#3176) * Some bug fixes in tensorflow graph converter and added DepthToSpace operator. * Made DepthToSpace better comply with other function syntax. * Added better shape inference for unusual situations. * Lint fixes. * Added depthtospace test. * Added test cases for value inference and depthtospace. * Added fill testing. * Made comment changes and added BroadcastTo op and tests. * Fixed underlining and unneeded opt_level forcing. * Added _infer_value assertion that all values to infer are available in passed parameters. --- python/tvm/relay/frontend/tensorflow.py | 100 +++++++++++-- .../frontend/tensorflow/test_forward.py | 131 ++++++++++++++++-- 2 files changed, 210 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b5a9ea5781aa..11026b9e5ad8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -34,6 +34,20 @@ __all__ = ['from_tensorflow'] +def _infer_value(input_val, params): + from tvm.contrib import graph_runtime + # Check that all free variables have associated parameters. + assert all(var.name_hint in params.keys() for var in ir_pass.free_vars( + input_val)), "All inputs to infer must be available in params." + func = _expr.Function(ir_pass.free_vars(input_val), input_val) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.context("llvm", 0) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + return m.get_output(0) + def _get_relay_op(op_name): try: op = getattr(_op, op_name) @@ -465,7 +479,12 @@ def _impl(inputs, attr, params): def _resize_bilinear(): def _impl(inputs, attr, params): - attr['size'] = attr['_output_shapes'][0][1:3] + size = attr['_output_shapes'][0][1:3] + # Important that the size is defined. If an axis is not, we need to infer what + # the shape should be. + if -1 in size: + size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + attr['size'] = size inputs.pop(1) # NHWC attr['layout'] = 'NHWC' @@ -574,15 +593,7 @@ def _impl(inputs, attr, params): except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) - with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) + params_new = _infer_value(inputs[1], params) inputs.pop(1) return AttrCvt( op_name="reshape", @@ -590,9 +601,63 @@ def _impl(inputs, attr, params): ignores=['Tshape'])(inputs, attr) return _impl + +def _depth_to_space(): + def _impl(inputs, attr, params): + # Need to handle data layouts differently. + input_shape = attr['_input_shapes'][inputs[0]] + block_size = int(attr['block_size']) + if attr['data_format'].decode("utf-8") == 'NHWC': + in_n, in_h, in_w, in_c = input_shape + new_c = int(in_c / (block_size * block_size)) + + # First expand input to larger dimension. + expanded = _op.reshape( + inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c)) + # Now reorder to expand spatial blocks. + transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5)) + # Finally reshape to proper output. + new_h = in_h * block_size + new_w = in_w * block_size + newshape = (in_n, new_h, new_w, new_c) + + else: # Handle NCHW layout + in_n, in_c, in_h, in_w = input_shape + new_c = int(in_c / (block_size * block_size)) + + expanded = _op.reshape( + inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w)) + transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2)) + new_h = in_h * block_size + new_w = in_w * block_size + newshape = (in_n, new_c, new_h, new_w) + + return AttrCvt( + op_name="reshape", + extras={'newshape': newshape}, + ignores=['data_format', 'block_size'])([transposed], attr) + + return _impl + + def _bias_add(): def _impl(inputs, attr, params): - return _op.add(inputs[0], inputs[1]) + # Must expand for proper broadcasting in NCHW. + if attr['data_format'].decode("utf-8") == 'NCHW': + bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) + else: + bias = inputs[1] + return _op.add(inputs[0], bias) + return _impl + +def _broadcast_to(): + def _impl(inputs, attr, params): + if isinstance(inputs[1], _expr.Var): + shape = params[inputs[1].name_hint] + else: + shape = _infer_value(inputs[1], params) + shape = list(shape.asnumpy().reshape([-1])) + return _op.broadcast_to(inputs[0], shape) return _impl def _squeeze(): @@ -666,9 +731,15 @@ def _impl(inputs, attr, params): def _fill(): def _impl(inputs, attr, params): + output_shape = attr['_output_shapes'][0] + # Output shape must be defined to avoid errors. If any axis is not, we must + # try to compute its shape. + if -1 in output_shape: + output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() + fill_arg = params.pop(inputs.pop(1).name_hint) return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name), - attr['_output_shapes'][0], attr['T'].name) + output_shape, attr['T'].name) return _impl def _lrn(): @@ -1115,6 +1186,7 @@ def _impl(inputs, attr, params): 'BatchNormWithGlobalNormalization' : _batch_norm(), 'BatchToSpaceND' : _batch_to_space_nd(), 'BiasAdd' : _bias_add(), + 'BroadcastTo' : _broadcast_to(), 'Cast' : _cast(), 'Ceil' : AttrCvt('ceil'), 'CheckNumerics' : _check_numerics(), @@ -1123,6 +1195,7 @@ def _impl(inputs, attr, params): 'Conv2D' : _conv('conv'), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), + 'DepthToSpace' : _depth_to_space(), 'Equal' : _broadcast('equal'), 'Elu' : _elu(), 'Exp' : AttrCvt('exp'), @@ -1158,11 +1231,12 @@ def _impl(inputs, attr, params): 'Prod' : _prod(), 'Range' : _range(), 'Rank' : _rank(), - 'RealDiv' : _elemwise('div'), + 'RealDiv' : _elemwise('divide'), 'Relu' : AttrCvt('relu'), 'Relu6' : _relu6(), 'Reshape' : _reshape(), 'ResizeBilinear' : _resize_bilinear(), + 'ResizeBicubic' : _resize_bilinear(), 'ReverseV2' : _reverse_v2(), 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 90ee75823518..e4626e0d60ff 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -47,7 +47,8 @@ def convert_to_list(x): x = [x] return x -def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None): +def run_tvm_graph(graph_def, input_data, input_node, num_output=1, + target='llvm', out_names=None, opt_level=3): """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) @@ -71,7 +72,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm' layout=layout, shape=shape_dict, outputs=out_names) - with relay.build_config(opt_level=3): + with relay.build_config(opt_level=opt_level): graph, lib, params = relay.build(sym, target, params=params) ctx = tvm.context(target, 0) @@ -85,8 +86,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm' # execute m.run() # get outputs - assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format( - out_names, num_output) + assert out_names is None or num_output == len(out_names), ( + "out_names: {} num_output: {}".format(out_names, num_output)) tvm_output_list = [] for i in range(0, num_output): tvm_output = m.get_output(i) @@ -111,7 +112,8 @@ def run_tf_graph(sess, input_data, input_node, output_node): return output_data -def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False): +def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, + no_gpu=False, opt_level=3): """Generic function to generate and compare tensorflow and TVM output""" out_name = convert_to_list(out_name) @@ -142,8 +144,9 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, if no_gpu and device == 'cuda': continue - tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device, - out_names=out_name, num_output=len(out_name)) + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, + target=device, out_names=out_name, + num_output=len(out_name), opt_level=opt_level) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared for i in range(len(tf_output)): @@ -411,6 +414,23 @@ def test_forward_reshape(): _test_reshape(np.arange(6), [-1]) ####################################################################### +# DepthToSpace +# ------------ + +def _test_depthtospace(data, block_size): + """ One iteration of depth_to_space operation with given data and block size """ + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + array_ops.depth_to_space(in_data, block_size) + + compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0') + +def test_forward_depthtospace(): + _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2) + _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4) + + ####################################################################### # Squeeze # ------- @@ -840,16 +860,108 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners): with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype) + shape_data = constant_op.constant( + shape_data, shape=shape_data.shape, dtype=shape_data.dtype) tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners) compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') +def _test_resize_bilinear_from_tensor(in_shape, align_corners): + """ One iteration of resize bilinear with non-constant output shape, requires + value inference to get proper output shape.""" + + data = np.random.uniform(size=in_shape).astype('float32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder( + shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype) + to_shape = tf.shape(in_data)[2:] + tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners) + + compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') + def test_forward_resize_bilinear(): """ Resize Bilinear """ _test_resize_bilinear((4, 16, 32, 32), [50, 50], False) _test_resize_bilinear((6, 32, 64, 64), [20, 20], True) + _test_resize_bilinear_from_tensor((4, 16, 32, 32), False) + _test_resize_bilinear_from_tensor((6, 32, 50, 50), True) + +####################################################################### +# BroadcastTo +# ----------- + +def _test_broadcast_to(in_shape, to_shape): + """ One iteration of broadcast_to""" + + data = np.random.uniform(size=in_shape).astype('float32') + shape_data = np.array(to_shape).astype('int32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + shape_data = constant_op.constant( + shape_data, shape=shape_data.shape, dtype=shape_data.dtype) + tf.broadcast_to(in_data, shape_data) + + compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0) + + +def _test_broadcast_to_from_tensor(in_shape): + """ One iteration of broadcast_to with unknown shape at graph build""" + + data = np.random.uniform(size=in_shape).astype('float32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder( + shape=[None], dtype=data.dtype) + + shape_data = tf.multiply(tf.shape(in_data), 32) + tf.broadcast_to(in_data, shape_data) + + compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0') + + +def test_forward_broadcast_to(): + """ Resize Bilinear """ + + _test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32]) + _test_broadcast_to((6, 32, 32, 1), [6, 32, 32, 16]) + _test_broadcast_to_from_tensor((1)) + + +####################################################################### +# Fill +# ---- + +def _test_fill(in_shape): + """ Use the fill op to create a tensor of ones with non-constant shape.""" + + with tf.Graph().as_default(): + tf.ones(shape=in_shape, dtype='float32') + compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1) + +def _test_fill_from_tensor(in_shape): + """ Use the fill op to create a tensor of ones with non-constant shape. + Some extra ops need to be added here to prevent the graph from + being fully constant and folded away.""" + + data = np.random.uniform(size=in_shape).astype('float32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder( + shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype) + + x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype) + y = tf.math.add(in_data, tf.reduce_mean(x), name='out1') + compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0') + +def test_forward_fill(): + """ Resize Bilinear """ + + _test_fill((32)) + _test_fill((6, 32, 64, 64)) + _test_fill_from_tensor((6, 32, 64, 64)) ####################################################################### # Crop to bounding box @@ -1567,9 +1679,12 @@ def test_placeholder(): # Transforms test_forward_transpose() test_forward_reshape() + test_forward_depthtospace() test_forward_squeeze() test_forward_pack() test_forward_resize_bilinear() + test_forward_broadcast_to() + test_forward_fill() test_forward_crop() test_forward_pad() test_forward_gather() From d98318caa30d12cec6a011447bd0eb8713bf3ebc Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Sat, 18 May 2019 01:13:17 +0800 Subject: [PATCH 017/176] [CODEGEN][CUDA][OPENCL] Handle INF and NAN (#3194) --- src/codegen/codegen_cuda.cc | 19 +++++++++++-- src/codegen/codegen_cuda.h | 7 ++++- src/codegen/codegen_opencl.cc | 13 +++++++++ src/codegen/codegen_opencl.h | 1 + tests/python/unittest/test_codegen_cuda.py | 30 ++++++++++++++++++++ tests/python/unittest/test_codegen_opencl.py | 27 ++++++++++++++++++ 6 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index ef92f9ae3175..22dde1c46389 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_math_constants_h_) { + decl_stream << "#include \n"; + } + return CodeGenC::Finish(); } @@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { / switch (op->type.bits()) { case 64: case 32: { std::ostringstream temp; - temp << std::scientific << op->value; - if (op->type.bits() == 32) temp << 'f'; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); + p->need_math_constants_h_ = true; + } else { + temp << std::scientific << op->value; + if (op->type.bits() == 32) temp << 'f'; + } p->MarkConst(temp.str()); os << temp.str(); break; diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 381784a13a57..acd759f33889 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); void AddFunction(LoweredFunc f); std::string Finish(); - bool need_include_path() { return (enable_fp16_ || enable_int8_); } + bool need_include_path() { + return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + } // override behavior void VisitStmt_(const ir::For* op) final; void PrintStorageSync(const Call* op) final; @@ -70,6 +72,9 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable int8 bool enable_int8_{false}; + // whether need math_constants.h + bool need_math_constants_h_{false}; + friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); }; } // namespace codegen diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 382124a7ed2d..0b33bf43c151 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT( CodeGenC::VisitExpr_(op, os); } +void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) + if (std::isinf(op->value)) { + if (op->value < 0) { + os << "-"; + } + os << "INFINITY"; + } else if (std::isnan(op->value)) { + os << "NAN"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + runtime::Module BuildOpenCL(Array funcs) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 0eff3a633ba3..36a55a545cbd 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py index f28b4ccfd1da..8fe6720830a5 100644 --- a/tests/python/unittest/test_codegen_cuda.py +++ b/tests/python/unittest/test_codegen_cuda.py @@ -125,8 +125,38 @@ def check_cuda(n, value): check_cuda(64, 0) check_cuda(64, -3) + +def test_cuda_inf_nan(): + target = 'cuda' + def check_inf_nan(ctx, n, value, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + inf_value = tvm.const(value, dtype=dtype) + C = tvm.compute((n,), lambda i: inf_value, name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + + ctx = tvm.context(target, 0) + + check_inf_nan(ctx, 1, -float('inf'), 'float32') + check_inf_nan(ctx, 1, -float('inf'), 'float64') + check_inf_nan(ctx, 1, float('inf'), 'float32') + check_inf_nan(ctx, 1, float('inf'), 'float64') + check_inf_nan(ctx, 1, float('nan'), 'float32') + check_inf_nan(ctx, 1, float('nan'), 'float64') + + if __name__ == "__main__": test_cuda_vectorize_add() test_cuda_multiply_add() test_cuda_vectorize_load() test_cuda_make_int8x4() + test_cuda_inf_nan() diff --git a/tests/python/unittest/test_codegen_opencl.py b/tests/python/unittest/test_codegen_opencl.py index c484664bdfd8..71fc4f9a7f35 100644 --- a/tests/python/unittest/test_codegen_opencl.py +++ b/tests/python/unittest/test_codegen_opencl.py @@ -66,6 +66,33 @@ def check_select(ctx, n, dtype): check_select(ctx, 1, 'int16') check_select(ctx, 1, 'uint16') +def test_opencl_inf_nan(): + def check_inf_nan(ctx, n, value, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + inf_value = tvm.const(value, dtype=dtype) + C = tvm.compute((n,), lambda i: inf_value, name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + if not tvm.module.enabled(target): + print("skip because opencl is not enabled..") + return + + ctx = tvm.context(target, 0) + + check_inf_nan(ctx, 1, -float('inf'), 'float32') + check_inf_nan(ctx, 1, -float('inf'), 'float64') + check_inf_nan(ctx, 1, float('inf'), 'float32') + check_inf_nan(ctx, 1, float('inf'), 'float64') + check_inf_nan(ctx, 1, float('nan'), 'float32') + check_inf_nan(ctx, 1, float('nan'), 'float64') + if __name__ == "__main__": test_opencl_ternary_expression() + test_opencl_inf_nan() From 4eb19ad78ce465fb5ba5c88406ec076719665a27 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Fri, 17 May 2019 10:29:07 -0700 Subject: [PATCH 018/176] [ARM] Fix concat (#3061) --- python/tvm/relay/op/_transform.py | 3 ++- python/tvm/relay/op/op.py | 7 ++++++ topi/python/topi/arm_cpu/injective.py | 29 ++++++++++++++++++++++++ topi/tests/python/test_topi_transform.py | 3 ++- 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 2eec6d03e7cd..95fb2ad18a25 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -23,6 +23,7 @@ schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective +schedule_concatenate = _reg.schedule_concatenate _reg.register_schedule("collapse_sum_like", _schedule_reduce) @@ -46,7 +47,7 @@ _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("stack", schedule_injective) -_reg.register_schedule("concatenate", schedule_injective) +_reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 6ba207934d1b..906bf255d46e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target): with target: return topi.generic.schedule_injective(outputs) + +def schedule_concatenate(attrs, outputs, target): + """Generic schedule for concatinate.""" + with target: + return topi.generic.schedule_concatenate(outputs) + + __DEBUG_COUNTER__ = 0 def debug(expr, debug_func=None): diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 9afdc32cf117..028558f69e91 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -51,3 +51,32 @@ def schedule_injective(outs): elif len(s[x].op.axis) >= 2: s[x].parallel(s[x].op.axis[0]) return s + +@generic.schedule_concatenate.register(["arm_cpu"]) +def schedule_concatenate(outs): + """Schedule for concatenate op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of reduce in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + x = outs[0] + tvm.schedule.AutoInlineInjective(s) + if len(s[x].op.axis) >= 4: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 3: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 2: + s[x].parallel(s[x].op.axis[0]) + return s diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index a078eacae85b..d29fb64544b9 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -127,7 +127,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - s = topi.generic.schedule_injective(out_tensor) + s = topi.generic.schedule_concatenate(out_tensor) foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate") data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes] @@ -476,6 +476,7 @@ def test_concatenate(): (12, 6, 7, 3), (8, 6, 7, 3), (2, 6, 7, 3)], 0) + verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1) def test_stack(): From dafa2f8c5c84be97ac24abc581016705e66a1f1c Mon Sep 17 00:00:00 2001 From: Hua Date: Mon, 20 May 2019 10:01:49 -0700 Subject: [PATCH 019/176] [BugFix][VTA] Fix bug in vta runtime DepPop function. (#3208) Issue: One of existing illegal dependency check's condition always true, the correct logic actually should be such check for store and load. Solution: Fix the said logic issue. --- vta/src/runtime.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 9c8de1aaae5e..7af0de1a8f8b 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -514,7 +514,7 @@ class InsnQueue : public BaseQueue { } // Impossible condition CHECK(from != kLoadStage || to != kStoreStage); - CHECK(to != kLoadStage || to != kComputeStage); + CHECK(from != kStoreStage || to != kLoadStage); } // Insert dependency push of load void DepPush(int from, int to) { From 2039d2a0afcfff33e38c39992b672ab73a2727f6 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Mon, 20 May 2019 10:05:36 -0700 Subject: [PATCH 020/176] [VTA] [TSIM] Improve tsim example (#3206) --- vta/apps/tsim_example/CMakeLists.txt | 4 +- vta/apps/tsim_example/Makefile | 13 +--- vta/apps/tsim_example/README.md | 10 +-- .../cmake/modules/{tsim.cmake => hw.cmake} | 62 +++++++++---------- .../cmake/modules/{driver.cmake => sw.cmake} | 6 +- .../{python/tsim => config}/config.json | 2 +- .../{python/tsim => config}/config.py | 0 .../python/tsim/{load.py => driver.py} | 54 +++++++--------- vta/apps/tsim_example/src/driver.cc | 7 +-- .../python/{test_tsim.py => add_by_one.py} | 11 ++-- 10 files changed, 71 insertions(+), 98 deletions(-) rename vta/apps/tsim_example/cmake/modules/{tsim.cmake => hw.cmake} (67%) rename vta/apps/tsim_example/cmake/modules/{driver.cmake => sw.cmake} (81%) rename vta/apps/tsim_example/{python/tsim => config}/config.json (82%) rename vta/apps/tsim_example/{python/tsim => config}/config.py (100%) rename vta/apps/tsim_example/python/tsim/{load.py => driver.py} (51%) rename vta/apps/tsim_example/tests/python/{test_tsim.py => add_by_one.py} (90%) diff --git a/vta/apps/tsim_example/CMakeLists.txt b/vta/apps/tsim_example/CMakeLists.txt index 4163c88ce3b8..28cfded75823 100644 --- a/vta/apps/tsim_example/CMakeLists.txt +++ b/vta/apps/tsim_example/CMakeLists.txt @@ -35,5 +35,5 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND endif() # Module rules -include(cmake/modules/tsim.cmake) -include(cmake/modules/driver.cmake) +include(cmake/modules/hw.cmake) +include(cmake/modules/sw.cmake) diff --git a/vta/apps/tsim_example/Makefile b/vta/apps/tsim_example/Makefile index e4911ceda419..2d7629ce12b2 100644 --- a/vta/apps/tsim_example/Makefile +++ b/vta/apps/tsim_example/Makefile @@ -17,16 +17,7 @@ export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH) -BUILD_DIR = $(shell python python/tsim/config.py --get-build-name) - -TVM_DIR = $(abspath ../../../) - -TSIM_TARGET = verilog -TSIM_TOP_NAME = TestAccel -TSIM_BUILD_NAME = build - -# optional -TSIM_TRACE_NAME = trace.vcd +BUILD_DIR = $(shell python3 config/config.py --get-build-name) default: cmake run @@ -39,7 +30,7 @@ $(BUILD_DIR): mkdir -p $@ run: - python3 tests/python/test_tsim.py | grep PASS + python3 tests/python/add_by_one.py | grep PASS clean: -rm -rf $(BUILD_DIR) diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index 4cde4242dc28..b557b24ac690 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -64,8 +64,8 @@ These examples are located at `/vta/apps/tsim_example`. * Run `make` * Some pointers - * Build cmake script for driver `/vta/apps/tsim_example/cmake/modules/driver.cmake` - * Build cmake script for tsim `/vta/apps/tsim_example/cmake/modules/tsim.cmake` - * Software driver that handles the VTA accelerator `/vta/apps/tsim_example/src/driver.cc` - * VTA add-by-one accelerator (Verilog) `/vta/apps/tsim_example/hardware/verilog` - * VTA add-by-one accelerator (Chisel) `/vta/apps/tsim_example/hardware/chisel` + * Build cmake script for software library`/vta/apps/tsim_example/cmake/modules/sw.cmake` + * Build cmake script for hardware library`/vta/apps/tsim_example/cmake/modules/hw.cmake` + * Software driver that handles the accelerator `/vta/apps/tsim_example/src/driver.cc` + * Add-by-one accelerator in Verilog `/vta/apps/tsim_example/hardware/verilog` + * Add-by-one accelerator in Chisel3 `/vta/apps/tsim_example/hardware/chisel` diff --git a/vta/apps/tsim_example/cmake/modules/tsim.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake similarity index 67% rename from vta/apps/tsim_example/cmake/modules/tsim.cmake rename to vta/apps/tsim_example/cmake/modules/hw.cmake index 4c81f288e45a..87dd72b2e626 100644 --- a/vta/apps/tsim_example/cmake/modules/tsim.cmake +++ b/vta/apps/tsim_example/cmake/modules/hw.cmake @@ -16,7 +16,7 @@ # under the License. if(MSVC) - message(STATUS "TSIM build is skipped in Windows..") + message(STATUS "[TSIM_HW] build is skipped in Windows..") else() find_program(PYTHON NAMES python python3 python3.6) find_program(VERILATOR NAMES verilator) @@ -24,26 +24,20 @@ else() if (VERILATOR AND PYTHON) if (TSIM_TOP_NAME STREQUAL "") - message(FATAL_ERROR "TSIM_TOP_NAME should be defined") + message(FATAL_ERROR "[TSIM_HW] TSIM_TOP_NAME should be defined") endif() if (TSIM_BUILD_NAME STREQUAL "") - message(FATAL_ERROR "TSIM_BUILD_NAME should be defined") + message(FATAL_ERROR "[TSIM_HW] TSIM_BUILD_NAME should be defined") endif() - set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/python/tsim/config.py) + set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/config/config.py) - execute_process(COMMAND ${TSIM_CONFIG} --get-target OUTPUT_VARIABLE __TSIM_TARGET) - execute_process(COMMAND ${TSIM_CONFIG} --get-top-name OUTPUT_VARIABLE __TSIM_TOP_NAME) - execute_process(COMMAND ${TSIM_CONFIG} --get-build-name OUTPUT_VARIABLE __TSIM_BUILD_NAME) - execute_process(COMMAND ${TSIM_CONFIG} --get-use-trace OUTPUT_VARIABLE __TSIM_USE_TRACE) - execute_process(COMMAND ${TSIM_CONFIG} --get-trace-name OUTPUT_VARIABLE __TSIM_TRACE_NAME) - - string(STRIP ${__TSIM_TARGET} TSIM_TARGET) - string(STRIP ${__TSIM_TOP_NAME} TSIM_TOP_NAME) - string(STRIP ${__TSIM_BUILD_NAME} TSIM_BUILD_NAME) - string(STRIP ${__TSIM_USE_TRACE} TSIM_USE_TRACE) - string(STRIP ${__TSIM_TRACE_NAME} TSIM_TRACE_NAME) + execute_process(COMMAND ${TSIM_CONFIG} --get-target OUTPUT_VARIABLE TSIM_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND ${TSIM_CONFIG} --get-top-name OUTPUT_VARIABLE TSIM_TOP_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND ${TSIM_CONFIG} --get-build-name OUTPUT_VARIABLE TSIM_BUILD_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND ${TSIM_CONFIG} --get-use-trace OUTPUT_VARIABLE TSIM_USE_TRACE OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND ${TSIM_CONFIG} --get-trace-name OUTPUT_VARIABLE TSIM_TRACE_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) set(TSIM_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/${TSIM_BUILD_NAME}) @@ -60,24 +54,24 @@ else() COMMAND ${SBT} publishLocal RESULT_VARIABLE RETCODE) if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM] sbt failed to install VTA scala package") + message(FATAL_ERROR "[TSIM_HW] sbt failed to install VTA scala package") endif() # Chisel - Scala to Verilog compilation set(TSIM_CHISEL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/chisel) - set(CHISEL_TARGET_DIR ${TSIM_BUILD_DIR}/chisel) - set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_TARGET_DIR} --top-name ${TSIM_TOP_NAME}") + set(CHISEL_BUILD_DIR ${TSIM_BUILD_DIR}/chisel) + set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_BUILD_DIR} --top-name ${TSIM_TOP_NAME}") execute_process(WORKING_DIRECTORY ${TSIM_CHISEL_DIR} COMMAND ${SBT} ${CHISEL_OPT} RESULT_VARIABLE RETCODE) if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM] sbt failed to compile from Chisel to Verilog.") + message(FATAL_ERROR "[TSIM_HW] sbt failed to compile from Chisel to Verilog.") endif() - file(GLOB VERILATOR_RTL_SRC ${CHISEL_TARGET_DIR}/*.v) + file(GLOB VERILATOR_RTL_SRC ${CHISEL_BUILD_DIR}/*.v) else() - message(FATAL_ERROR "[TSIM] sbt should be installed for Chisel") + message(FATAL_ERROR "[TSIM_HW] sbt should be installed for Chisel") endif() # sbt elseif (TSIM_TARGET STREQUAL "verilog") @@ -87,24 +81,24 @@ else() file(GLOB VERILATOR_RTL_SRC ${VTA_VERILOG_DIR}/*.v ${TSIM_VERILOG_DIR}/*.v) else() - message(STATUS "[TSIM] target language can be only verilog or chisel...") + message(FATAL_ERROR "[TSIM_HW] target language can be only verilog or chisel...") endif() # TSIM_TARGET if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog") # Check if tracing can be enabled if (NOT TSIM_USE_TRACE STREQUAL "OFF") - message(STATUS "[TSIM] Verilog enable tracing") + message(STATUS "[TSIM_HW] Verilog enable tracing") else() - message(STATUS "[TSIM] Verilator disable tracing") + message(STATUS "[TSIM_HW] Verilator disable tracing") endif() # Verilator - Verilog to C++ compilation - set(VERILATOR_TARGET_DIR ${TSIM_BUILD_DIR}/verilator) + set(VERILATOR_BUILD_DIR ${TSIM_BUILD_DIR}/verilator) set(VERILATOR_OPT +define+RANDOMIZE_GARBAGE_ASSIGN +define+RANDOMIZE_REG_INIT) list(APPEND VERILATOR_OPT +define+RANDOMIZE_MEM_INIT --x-assign unique) list(APPEND VERILATOR_OPT --output-split 20000 --output-split-cfuncs 20000) - list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_TARGET_DIR}) + list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_BUILD_DIR}) list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC}) if (NOT TSIM_USE_TRACE STREQUAL "OFF") @@ -114,7 +108,7 @@ else() execute_process(COMMAND ${VERILATOR} ${VERILATOR_OPT} RESULT_VARIABLE RETCODE) if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM] Verilator failed to compile Verilog to C++...") + message(FATAL_ERROR "[TSIM_HW] Verilator failed to compile Verilog to C++...") endif() # Build shared library (.so) @@ -126,9 +120,9 @@ else() list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp) endif() - file(GLOB VERILATOR_GEN_SRC ${VERILATOR_TARGET_DIR}/*.cpp) + file(GLOB VERILATOR_GEN_SRC ${VERILATOR_BUILD_DIR}/*.cpp) file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc) - add_library(tsim SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) + add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) if (NOT TSIM_USE_TRACE STREQUAL "OFF") @@ -136,17 +130,17 @@ else() else() list(APPEND VERILATOR_DEF VM_TRACE=0) endif() - target_compile_definitions(tsim PRIVATE ${VERILATOR_DEF}) - target_compile_options(tsim PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h) - target_include_directories(tsim PRIVATE ${VERILATOR_TARGET_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include) + target_compile_definitions(hw PRIVATE ${VERILATOR_DEF}) + target_compile_options(hw PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h) + target_include_directories(hw PRIVATE ${VERILATOR_BUILD_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include) if(APPLE) - set_target_properties(tsim PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(hw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif(APPLE) endif() # TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog" else() - message(STATUS "[TSIM] could not find Python or Verilator, build is skipped...") + message(STATUS "[TSIM_HW] could not find Python or Verilator, build is skipped...") endif() # VERILATOR endif() # MSVC diff --git a/vta/apps/tsim_example/cmake/modules/driver.cmake b/vta/apps/tsim_example/cmake/modules/sw.cmake similarity index 81% rename from vta/apps/tsim_example/cmake/modules/driver.cmake rename to vta/apps/tsim_example/cmake/modules/sw.cmake index c4c80637918f..d0368c3edc75 100644 --- a/vta/apps/tsim_example/cmake/modules/driver.cmake +++ b/vta/apps/tsim_example/cmake/modules/sw.cmake @@ -16,9 +16,9 @@ # under the License. file(GLOB TSIM_SW_SRC src/driver.cc) -add_library(driver SHARED ${TSIM_SW_SRC}) -target_include_directories(driver PRIVATE ${VTA_DIR}/include) +add_library(sw SHARED ${TSIM_SW_SRC}) +target_include_directories(sw PRIVATE ${VTA_DIR}/include) if(APPLE) - set_target_properties(driver PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif(APPLE) diff --git a/vta/apps/tsim_example/python/tsim/config.json b/vta/apps/tsim_example/config/config.json similarity index 82% rename from vta/apps/tsim_example/python/tsim/config.json rename to vta/apps/tsim_example/config/config.json index 887eaac67d74..5f9ee69904fd 100644 --- a/vta/apps/tsim_example/python/tsim/config.json +++ b/vta/apps/tsim_example/config/config.json @@ -2,6 +2,6 @@ "TARGET" : "verilog", "TOP_NAME" : "TestAccel", "BUILD_NAME" : "build", - "USE_TRACE" : "off", + "USE_TRACE" : "OFF", "TRACE_NAME" : "trace" } diff --git a/vta/apps/tsim_example/python/tsim/config.py b/vta/apps/tsim_example/config/config.py similarity index 100% rename from vta/apps/tsim_example/python/tsim/config.py rename to vta/apps/tsim_example/config/config.py diff --git a/vta/apps/tsim_example/python/tsim/load.py b/vta/apps/tsim_example/python/tsim/driver.py similarity index 51% rename from vta/apps/tsim_example/python/tsim/load.py rename to vta/apps/tsim_example/python/tsim/driver.py index ef94fa97d206..997d9d527bfe 100644 --- a/vta/apps/tsim_example/python/tsim/load.py +++ b/vta/apps/tsim_example/python/tsim/driver.py @@ -21,36 +21,24 @@ import os.path as osp from sys import platform -def get_build_path(): - curr_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) - cfg = json.load(open(osp.join(curr_path, 'config.json'))) - return osp.join(curr_path, "..", "..", cfg['BUILD_NAME']) - -def get_lib_ext(): - if platform == "darwin": - ext = ".dylib" - else: - ext = ".so" - return ext - -def get_lib_path(name): - build_path = get_build_path() - ext = get_lib_ext() - libname = name + ext - return osp.join(build_path, libname) - -def _load_driver_lib(): - lib = get_lib_path("libdriver") - try: - return [ctypes.CDLL(lib, ctypes.RTLD_GLOBAL)] - except OSError: - return [] - -def load_driver(): - return tvm.get_global_func("tvm.vta.driver") - -def load_tsim(): - lib = get_lib_path("libtsim") - return tvm.module.load(lib, "vta-tsim") - -LIBS = _load_driver_lib() +def driver(hw, sw): + _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) + _root_path = osp.join(_cur_path, "..", "..") + _cfg_file = osp.join(_root_path, "config", "config.json") + _cfg = json.load(open(_cfg_file)) + _ext = ".dylib" if platform == "darwin" else ".so" + _hw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], hw + _ext) + _sw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], sw + _ext) + + def load_dll(dll): + try: + return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)] + except OSError: + return [] + + def run(a, b): + load_dll(_sw_lib) + f = tvm.get_global_func("tvm.vta.driver") + m = tvm.module.load(_hw_lib, "vta-tsim") + f(m, a, b) + return run diff --git a/vta/apps/tsim_example/src/driver.cc b/vta/apps/tsim_example/src/driver.cc index 9898537a3f25..c11a8f8a3ee7 100644 --- a/vta/apps/tsim_example/src/driver.cc +++ b/vta/apps/tsim_example/src/driver.cc @@ -35,9 +35,9 @@ uint32_t get_half_addr(void *p, bool upper) { using vta::dpi::DPIModuleNode; using tvm::runtime::Module; -class TestDriver { +class Device { public: - TestDriver(Module module) + Device(Module module) : module_(module) { dpi_ = static_cast( module.operator->()); @@ -71,7 +71,6 @@ class TestDriver { } } - private: DPIModuleNode* dpi_; Module module_; }; @@ -84,7 +83,7 @@ TVM_REGISTER_GLOBAL("tvm.vta.driver") Module dev_mod = args[0]; DLTensor* A = args[1]; DLTensor* B = args[2]; - TestDriver dev_(dev_mod); + Device dev_(dev_mod); dev_.Run(A->shape[0], A->data, B->data); }); diff --git a/vta/apps/tsim_example/tests/python/test_tsim.py b/vta/apps/tsim_example/tests/python/add_by_one.py similarity index 90% rename from vta/apps/tsim_example/tests/python/test_tsim.py rename to vta/apps/tsim_example/tests/python/add_by_one.py index fd032f91914e..6e0d094367b5 100644 --- a/vta/apps/tsim_example/tests/python/test_tsim.py +++ b/vta/apps/tsim_example/tests/python/add_by_one.py @@ -17,7 +17,8 @@ import tvm import numpy as np -from tsim.load import load_driver, load_tsim + +from tsim.driver import driver def test_tsim(i): rmin = 1 # min vector size of 1 @@ -26,13 +27,13 @@ def test_tsim(i): ctx = tvm.cpu(0) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) - tsim = load_tsim() - f = load_driver() - f(tsim, a, b) + f = driver("libhw", "libsw") + f(a, b) emsg = "[FAIL] test number:{} n:{}".format(i, n) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1, err_msg=emsg) print("[PASS] test number:{} n:{}".format(i, n)) if __name__ == "__main__": - for i in range(10): + times = 10 + for i in range(times): test_tsim(i) From c1745377b01902ff54abe251e2a9b43953a90e06 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 20 May 2019 10:07:01 -0700 Subject: [PATCH 021/176] [BugFix] Fix bug in cast to bool (#3207) --- src/codegen/llvm/codegen_llvm.cc | 8 ++++++ topi/tests/python/test_topi_math.py | 44 +++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 7946f906125f..bedcdc79ff1f 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); + } else if (to.is_uint() && to.bits() == 1) { + if (from.is_float()) { + llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.); + return builder_->CreateFCmpONE(value, zero); + } else { + llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0); + return builder_->CreateICmpNE(value, zero); + } } else if (!from.is_float() && !to.is_float()) { return builder_->CreateIntCast(value, target, from.is_int()); } else if (from.is_float() && to.is_int()) { diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index c180bc77e829..d6df450628d2 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -19,6 +19,7 @@ import topi import topi.testing from topi import util +from common import get_all_backend def test_util(): @@ -59,8 +60,7 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel', - 'aocl_sw_emu']: + for device in get_all_backend(): check_device(device) @@ -77,6 +77,46 @@ def check_device(device): test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True) + +def test_cast(): + def verify(from_dtype, to_dtype, low=-100, high=100): + shape = (5, 4) + A = tvm.placeholder(shape, dtype=from_dtype, name="A") + B = topi.cast(A, to_dtype) + + if from_dtype == "bool": + a_np = np.random.choice([True, False], size=shape) + else: + a_np = np.random.uniform(low, high, size=shape).astype(from_dtype) + if to_dtype == "bool": + a_np = a_np - a_np[2, 3] + b_np = a_np.astype(to_dtype) + + for device in get_all_backend(): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + foo = tvm.build(s, [A, B], device) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx) + foo(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np) + + verify("int32", "float32") + verify("int32", "float64") + verify("int32", "bool") + verify("float32", "int32") + verify("float32", "float64") + verify("float32", "bool") + verify("bool", "float32") + verify("bool", "int32") + + if __name__ == "__main__": test_util() test_ewise() + test_cast() From c5549b1228f6f4d4d51cbf291dac9c3d73582ad1 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 20 May 2019 10:07:32 -0700 Subject: [PATCH 022/176] [Relay][ONNX] fix #3134 converter where initializers were not registered as nodes (#3143) --- python/tvm/relay/frontend/onnx.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 470f4197c908..c70f5aba39fe 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -934,7 +934,7 @@ def __init__(self, shape, dtype): self._renames = {} self._num_input = 0 self._num_param = 0 - self._shape = shape + self._shape = shape if shape else {} self._dtype = dtype def from_onnx(self, graph, opset): @@ -966,6 +966,9 @@ def from_onnx(self, graph, opset): if not init_tensor.name.strip(): raise ValueError("Tensor's name is required.") self._params[init_tensor.name] = self._parse_array(init_tensor) + self._nodes[init_tensor.name] = new_var(init_tensor.name, + shape=self._params[init_tensor.name].shape, + dtype=self._params[init_tensor.name].dtype) for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' @@ -1179,6 +1182,18 @@ def from_onnx(model, params : dict of str to tvm.NDArray The parameter dict to be used by relay """ + try: + import onnx + if hasattr(onnx.checker, 'check_model'): + # try use onnx's own model checker before converting any model + try: + onnx.checker.check_model(model) + except onnx.onnx_cpp2py_export.checker.ValidationError as e: + import warnings + # the checker is a bit violent about errors, so simply print warnings here + warnings.warn(str(e)) + except ImportError: + pass g = GraphProto(shape, dtype) graph = model.graph try: From 08fa79134bda45fc6595ef648125705679f7812f Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Mon, 20 May 2019 11:56:22 -0700 Subject: [PATCH 023/176] [Relay][TOPI] operator All (#3124) * [Relay][TOPI] operator All * Update tests/python/frontend/tensorflow/test_forward.py Co-Authored-By: yongwww <55wuyong@163.com> * fix comments * change to level 4 --- docs/api/python/topi.rst | 2 + docs/langref/relay_op.rst | 2 + include/tvm/expr_operator.h | 7 ++ python/tvm/relay/frontend/tensorflow.py | 12 ++++ python/tvm/relay/op/_reduce.py | 1 + python/tvm/relay/op/reduce.py | 66 +++++++++++++++++-- src/lang/expr_operator.cc | 10 +++ src/relay/op/tensor/reduce.cc | 37 +++++++++++ .../frontend/tensorflow/test_forward.py | 12 ++++ tests/python/relay/test_op_level4.py | 7 +- topi/include/topi/reduction.h | 21 ++++++ topi/python/topi/reduction.py | 25 +++++++ topi/src/topi.cc | 5 ++ topi/tests/python/test_topi_reduce.py | 47 +++++++++---- 14 files changed, 232 insertions(+), 22 deletions(-) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index eaa5dacd678e..0b217d4fe3af 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -88,6 +88,7 @@ List of operators topi.not_equal topi.greater_equal topi.less_equal + topi.all topi.logical_and topi.logical_or topi.logical_not @@ -140,6 +141,7 @@ topi .. autofunction:: topi.gather_nd .. autofunction:: topi.full .. autofunction:: topi.full_like +.. autofunction:: topi.all .. autofunction:: topi.max .. autofunction:: topi.sum .. autofunction:: topi.min diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index cd5677293571..836f8f30bfa8 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -135,6 +135,7 @@ This level enables additional math and transform operators. tvm.relay.greater_equal tvm.relay.less tvm.relay.less_equal + tvm.relay.all tvm.relay.logical_and tvm.relay.logical_or tvm.relay.logical_not @@ -277,6 +278,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.greater_equal .. autofunction:: tvm.relay.less .. autofunction:: tvm.relay.less_equal +.. autofunction:: tvm.relay.all .. autofunction:: tvm.relay.logical_and .. autofunction:: tvm.relay.logical_or .. autofunction:: tvm.relay.logical_not diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 2e1348e00470..f289bdd810d5 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -428,6 +428,13 @@ TVM_DLL Expr abs(Expr x); */ TVM_DLL Expr sum(Expr source, Array axis); +/*! + * \brief logical And of of source expression over axis + * \param source The source expression. + * \param axis List of iteration variables that will be used for reduction. + */ +TVM_DLL Expr all(Expr source, Array axis); + /*! * \brief max of of source expression over axis * \param source The source expression. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 11026b9e5ad8..7fe82ea7eac1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -767,6 +767,17 @@ def _impl(inputs, attr, params): ignores=['name', 'Tidx'])([inputs[0]], attr) return _impl +def _reduce_all(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[1].name_hint).asnumpy() + axis = tuple(axis) + return AttrCvt( + op_name='all', + extras={'axis': axis}, + transforms={'keep_dims':'keepdims'}, + ignores=['name', 'Tidx'])([inputs[0]], attr) + return _impl + def _square(): def _impl(inputs, attr, params): return _op.multiply(inputs[0], inputs[0]) @@ -1180,6 +1191,7 @@ def _impl(inputs, attr, params): # for N to 1 mapping, currently not supported(?) _convert_map = { 'Add' : _elemwise('add'), + 'All' : _reduce_all(), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index b97e3a8ce993..b7c9a79a8ad9 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("argmax", _schedule_reduce) _reg.register_schedule("argmin", _schedule_reduce) _reg.register_schedule("sum", _schedule_reduce) +_reg.register_schedule("all", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce) _reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 9d58a92041f3..0f2594600b0a 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False): return _make.sum(data, axis, keepdims, exclude) +def all(data, axis=None, keepdims=False, exclude=False): + """Computes the logical AND of boolean array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input boolean tensor + + axis : None or int or tuple of int + Axis or axes along which a sum is performed. The default, axis=None, + will sum all of the elements of the input array. If axis is + negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast + correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + data = relay.Constant(tvm.nd.array([[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]])) + + relay.all(data, axis=1) + # [[False, True, False], + # [False, False, False]] + + relay.all(data, axis=0) + # [[ True, False, False], + # [ True, True, False], + # [False, True, False]] + + """ + axis = [axis] if axis and isinstance(axis, int) else axis + return _make.all(data, axis, keepdims, exclude) + + def max(data, axis=None, keepdims=False, exclude=False): """ Computes the max of array elements over given axes. @@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 4504ee23f812..8537f17b763c 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -393,6 +393,16 @@ Expr sum(Expr source, Array rdom) { return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } +Expr all(Expr source, Array rdom) { + CHECK(source.type().is_bool()); + Var x("x", source.type()), y("y", source.type()); + Expr result = ir::And::make(x, y); + Expr identity_element = make_const(source.type(), true); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); +} + Expr max(Expr source, Array rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Max::make(x, y); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index a4ebd1e8d050..647e4d0f4f90 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -355,6 +355,43 @@ Example:: .set_attr("TOpPattern", kCommReduce); +Array AllCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::all); +} + + +RELAY_REGISTER_REDUCE_OP("all") +.describe(R"code(Computes the logical AND of boolean array elements over given axes. + +Example:: + + data = [[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]] + + all(data, axis=1) + [[False, True, False], + [False, False, False]] + + all(data, axis=0) + [[ True, False, False], + [ True, True, False], + [False, True, False]] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", AllCompute) +.set_attr("TOpPattern", kCommReduce); + + Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index e4626e0d60ff..023cdf5eb261 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1597,6 +1597,17 @@ def check_mean(ishape, **kwargs): check_mean((10, 8, 16, 32), axis=(2,3)) check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True) +####################################################################### +# All +# --- +def test_forward_all(): + """Test the All operator.""" + np_data = np.random.choice([True, False], size=(5, 7, 11)) + tf.reset_default_graph() + in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data") + tf.reduce_all(in_data, name="all") + compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0') + ####################################################################### # Relational operators # -------------------- @@ -1718,6 +1729,7 @@ def test_placeholder(): test_forward_reduce() test_forward_mean() test_forward_reduce_prod() + test_forward_all() # General test_forward_multi_input() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 0e44bf851dc4..aac4a6d4af16 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -138,6 +138,7 @@ def test_where(): def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): test_func = funcs[0] ref_func = funcs[1] + dtype = "bool" if ref_func in [np.all] else dtype x = relay.var("x", relay.TensorType(data, dtype)) z = test_func(x, axis, keepdims, exclude) @@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") return func = relay.Function([x], z) - x_data = np.random.uniform(size=data).astype(dtype) + x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \ + else np.random.uniform(size=data).astype(dtype) + if ref_func in [np.sum]: ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims) elif ref_func in [np.max, np.min, np.mean, np.prod]: @@ -194,6 +197,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.min, np.min], [relay.mean, np.mean], [relay.prod, np.prod], + [relay.all, np.all], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) @@ -203,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False): verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) + verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1)) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) verify_reduce(func, (4, 4, 3), None, False, False, ()) verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index b24c4577c4e5..09d1b4b1b33e 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -368,6 +368,27 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes); } +/*! +* \brief Creates an operation that computes the logical AND of elements +* over a given axis +* +* \param data The input boolean tensor +* \param axis The axes to reduce. If axis is empty, the operation will +* perform logical AND over all elements of the array. +* \param keepdims If this is set to true, the axes which are reduced are +* left in the result as dimensions with size one. This enables the result +* to broadcast correctly against the input array. +* \param atleast1d Whether the output need to be atleast1d. +* +* \return A Tensor whose op member is the all operation +*/ +inline Tensor all(const Tensor& data, + const Array& axis, + bool keepdims = false, + bool atleast1d = false) { + return CommReduce(data, axis, tvm::all, keepdims, atleast1d); +} + /*! * \brief Creates an operation that finds the minimum of elements over * a given axis. diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py index ce1326b78162..5079bf474deb 100644 --- a/topi/python/topi/reduction.py +++ b/topi/python/topi/reduction.py @@ -65,6 +65,31 @@ def sum(data, axis=None, keepdims=False): return cpp.sum(data, axis, keepdims) +def all(data, axis=None, keepdims=False): + """Logical AND of array elements over a given axis or a list of axes + + Parameters + ---------- + data : tvm.Tensor + The input tvm boolean tensor + + axis : None or int or tuple of int + Axis or axes along which a logical AND is performed. + The default, axis=None, will perform logical AND over all elements of the input array. + If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + Returns + ------- + ret : tvm.Tensor + """ + return cpp.all(data, axis, keepdims) + + def max(data, axis=None, keepdims=False): """Maximum of array elements over a given axis or a list of axes diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 1585d877b625..d3e0bc938f7c 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -265,6 +265,11 @@ TVM_REGISTER_GLOBAL("topi.prod") *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); }); +TVM_REGISTER_GLOBAL("topi.all") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); + }); + /* Ops from transform.h */ TVM_REGISTER_GLOBAL("topi.expand_dims") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py index 1882cbd7f896..6e6470dad588 100644 --- a/topi/tests/python/test_topi_reduce.py +++ b/topi/tests/python/test_topi_reduce.py @@ -50,6 +50,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32") out_dtype = dtype if type == "sum": B = topi.sum(A1, axis=axis, keepdims=keepdims) + elif type == "all": + B = topi.all(A, axis=axis, keepdims=keepdims) elif type == "max": B = topi.max(A1, axis=axis, keepdims=keepdims) elif type == "min": @@ -74,10 +76,16 @@ def check_device(device): foo = tvm.build(s, [A, B], device, name=type) # Test - in_npy = np.random.uniform(size=in_shape).astype(dtype) - in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) + if dtype == 'bool': + in_npy_map = in_npy = np.random.choice([True, False], size=in_shape) + else: + in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) + in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) + if type == "sum": out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) + elif type == "all" and dtype == 'bool': + out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) elif type == "max": out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) elif type == "min": @@ -113,26 +121,37 @@ def check_device(device): def test_reduce_map(): + verify_reduce_map_ele(in_shape=(32,), axis=0, keepdims=False, type="argmax") verify_reduce_map_ele(in_shape=(128, 24, 128, 24), - axis=(1, 2, 3), - keepdims=True, - type="sum") + axis=(1, 2, 3), + keepdims=True, + type="sum") + verify_reduce_map_ele(in_shape=(2, 3), + axis=None, + keepdims=True, + type="all", + dtype='bool') verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24), - axis=(1,), - keepdims=False, - type="max") + axis=(1,), + keepdims=False, + type="max") + verify_reduce_map_ele(in_shape=(32, 128, 24), + axis=None, + keepdims=True, + type="sum") verify_reduce_map_ele(in_shape=(32, 128, 24), - axis=None, - keepdims=True, - type="sum") + axis=None, + keepdims=True, + dtype='bool', + type="all") verify_reduce_map_ele(in_shape=(128, 24, 128, 24), - axis=(0, 2), - keepdims=False, - type="min") + axis=(0, 2), + keepdims=False, + type="min") verify_reduce_map_ele(in_shape=(32, 128), axis=1, keepdims=True, From 00a0099f8b49e3994442394e2202d51eca8e2042 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 21 May 2019 01:03:34 -0700 Subject: [PATCH 024/176] Add bing to reviewer (#3214) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4d4515e09410..7e0ad806c7ec 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -94,6 +94,7 @@ We do encourage everyone to work anything they are interested in. - [Leyuan Wang](https://github.com/Laurawly): @Laurawly - [Jian Weng](https://github.com/were): @were - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene +- [Bing Xu](https://github.com/antinucleon): @antinucleon - [Eddie Yan](https://github.com/eqy): @eqy - [Joshua Z. Zhang](https://github.com/zhreshold): @zhreshold - [Lianmin Zheng](https://github.com/merrymercy): @merrymercy From 04d2b0418f293aad4cacbb948109b11d0fba3328 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Tue, 21 May 2019 12:53:03 -0700 Subject: [PATCH 025/176] [relay][vm] remove throw in destructor (#3215) --- src/runtime/vm/vm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b2d326ec7792..6f9190e8907a 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -226,8 +226,8 @@ Instruction::~Instruction() { return; default: std::ostringstream out; - out << "Invalid instruction " << static_cast(this->op); - throw std::runtime_error(out.str()); + LOG(FATAL) << "Invalid instruction " << static_cast(this->op) + << "\n"; } } From 1bf6184bceef03620983f86e86bfc108889c52db Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Tue, 21 May 2019 12:53:58 -0700 Subject: [PATCH 026/176] [Relay][heterogeneous pass] remove on_device op after annotation (#3204) * remove on_device op after annotation * Update src/relay/pass/device_annotation.cc Co-Authored-By: MORINAGA <34588258+imorinaga@users.noreply.github.com> --- src/relay/pass/device_annotation.cc | 47 ++++++++++++++++- tests/python/relay/test_pass_annotation.py | 61 ++++++++++++++-------- 2 files changed, 85 insertions(+), 23 deletions(-) diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 0139cc912849..8807f6dd4cf4 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -485,7 +485,52 @@ class DeviceInfo { Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { RewriteAnnotation rewrote = RewriteAnnotation(); - return rewrote.Rewrite(expr, fallback_device); + Expr new_expr = rewrote.Rewrite(expr, fallback_device); + + // Remove OnDevice operators. Note that these operators are only present at the + // leaves after annotation. Therefore, we can simply reconstruct the + // Function/Expr by removing them directly. + if (const FunctionNode* fn = new_expr.as()) { + auto params = fn->params; + auto body = fn->body; + std::vector new_body; + if (const TupleNode* tuple = body.as()) { + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_body.push_back(field); + } + } + CHECK_GT(new_body.size(), 0U); + if (new_body.size() == 1) { + return FunctionNode::make(params, new_body[0], Type(nullptr), + fn->type_params, fn->attrs); + } else if (tuple->fields.size() == new_body.size()) { + return new_expr; + } else { + Tuple tuple_body = TupleNode::make(new_body); + return FunctionNode::make(params, tuple_body, Type(nullptr), + fn->type_params, fn->attrs); + } + } else { + return new_expr; + } + } else if (const TupleNode* tuple = new_expr.as()) { + std::vector new_fields; + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_fields.push_back(field); + } + } + CHECK_GT(new_fields.size(), 0U); + if (tuple->fields.size() == new_fields.size()) { + return new_fields.size() == 1 ? new_fields[0] : new_expr; + } else { + return new_fields.size() == 1 ? new_fields[0] + : TupleNode::make(new_fields); + } + } else { + return new_expr; + } } Map CollectDeviceInfo(const Expr& expr) { diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 9a77d2ffe856..98cf0f15446e 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -42,9 +42,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -58,6 +56,35 @@ def expected(): assert relay.ir_pass.alpha_equal(annotated_func, expected_func) +def test_annotate_expr(): + ctx1 = tvm.context(1) + ctx2 = tvm.context(2) + x = relay.var("x", shape=(3,)) + y = relay.var("y", shape=(3,)) + z = relay.var("z", shape=(3,)) + + def annotated(): + add = relay.add(x, y) + _add = relay.annotation.on_device(add, ctx1) + sub = relay.subtract(add, z) + _sub = relay.annotation.on_device(sub, ctx2) + expr = relay.Tuple([sub, _add, _sub]) + expr = relay.ir_pass.infer_type(expr) + expr = relay.ir_pass.rewrite_annotated_ops(expr, + ctx1.device_type) + return expr + + def expected(): + add = relay.add(x, y) + copy_add_sub = relay.device_copy(add, ctx1, ctx2) + sub = relay.subtract(copy_add_sub, z) + return sub + + annotated_expr = relay.ir_pass.infer_type(annotated()) + expected_expr = relay.ir_pass.infer_type(expected()) + assert relay.ir_pass.graph_equal(annotated_expr, expected_expr) + + def test_annotate_all(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) @@ -77,9 +104,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -91,6 +116,7 @@ def expected(): expected_func = relay.ir_pass.infer_type(expected()) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + def test_annotate_none(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) @@ -174,9 +200,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, tvm.context(3).device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[4]), - func.body[4]) + return func def expected(): conv2d_1 = relay.nn.conv2d( @@ -202,7 +226,7 @@ def expected(): kernel_size=(3, 3), padding=(1, 1)) - func = relay.Function([data1, weight, data2], conv2d_3) + func = relay.Function([data1, data2, weight], conv2d_3) return func def check_storage_and_device_types(): @@ -306,9 +330,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -358,9 +380,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[5]), - func.body[5]) + return func annotated_func = annotated() expected_func = get_func() @@ -386,9 +406,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[1]), - func.body[1]) + return func def expected(): add = relay.add(x, y) @@ -462,9 +480,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[3]), - func.body[3]) + return func def expected(): add = relay.add(a, b) @@ -506,6 +522,7 @@ def test_check_run(): if __name__ == "__main__": test_redundant_annotation() + test_annotate_expr() test_annotate_all() test_annotate_none() test_conv_network() From 5a798c8d744d60962390299231da77a5067f9a43 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Tue, 21 May 2019 16:05:28 -0700 Subject: [PATCH 027/176] [Contrib] cblas batch_matmul (#3210) --- cmake/modules/contrib/BLAS.cmake | 6 +- python/tvm/contrib/cblas.py | 55 +++++++++- src/contrib/cblas/cblas.cc | 171 ++++++++++++++++++++++------- src/contrib/cblas/gemm_common.h | 128 +++++++++++++++------ tests/python/contrib/test_cblas.py | 83 ++++++++++++-- 5 files changed, 351 insertions(+), 92 deletions(-) diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake index e1e151d6a9f8..a47f83771d37 100644 --- a/cmake/modules/contrib/BLAS.cmake +++ b/cmake/modules/contrib/BLAS.cmake @@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl") if(NOT IS_DIRECTORY ${USE_MKL_PATH}) set(USE_MKL_PATH /opt/intel/mkl) endif() - find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + if(APPLE) + find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + elseif(UNIX) + find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + endif() include_directories(${USE_MKL_PATH}/include) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index c656fcc2b966..7c024b792867 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -17,10 +17,10 @@ """External function interface to BLAS libraries.""" from __future__ import absolute_import as _abs -from .. import api as _api -from .. import intrin as _intrin +from .. import api as _api, intrin as _intrin -def matmul(lhs, rhs, transa=False, transb=False): + +def matmul(lhs, rhs, transa=False, transb=False, **kwargs): """Create an extern op that compute matrix mult of A and rhs with CrhsLAS This function serves as an example on how to call external libraries. @@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False): n = lhs.shape[1] if transa else lhs.shape[0] m = rhs.shape[0] if transb else rhs.shape[1] return _api.extern( - (n, m), [lhs, rhs], + (n, m), + [lhs, rhs], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): + """Create an extern op that compute batched matrix mult of A and rhs with CBLAS + This function serves as an example on how to call external libraries. + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + Returns + ------- + C : Tensor + The result tensor. + """ + b = lhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + return _api.extern( + (b, n, m), + [lhs, rhs], lambda ins, outs: _intrin.call_packed( - "tvm.contrib.cblas.matmul", - ins[0], ins[1], outs[0], transa, transb), name="C") + "tvm.contrib.cblas.batch_matmul" + if not iterative + else "tvm.contrib.cblas.batch_matmul_iterative", + ins[0], + ins[1], + outs[0], + transa, + transb, + ), + name="C", + **kwargs + ) diff --git a/src/contrib/cblas/cblas.cc b/src/contrib/cblas/cblas.cc index 4ca043f1bcfe..0f222e2f2a39 100644 --- a/src/contrib/cblas/cblas.cc +++ b/src/contrib/cblas/cblas.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,11 @@ * Copyright (c) 2017 by Contributors * \file Use external cblas library call. */ +#include #include #include -#include #include "gemm_common.h" - extern "C" { #if USE_MKL_BLAS == 1 #include @@ -40,56 +39,148 @@ namespace contrib { using namespace runtime; -inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { - return trans ? CblasTrans : CblasNoTrans; -} +inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } struct CblasSgemmOp { typedef float TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - cblas_sgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; struct CblasDgemmOp { typedef double TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - cblas_dgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; +struct CblasSgemmBatchOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; + +struct CblasSgemmBatchIterativeOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; + +struct CblasDgemmBatchOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; + +struct CblasDgemmBatchIterativeOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); +}); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }); +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/cblas/gemm_common.h b/src/contrib/cblas/gemm_common.h index fe38b2a67513..2bcefb2f26bb 100644 --- a/src/contrib/cblas/gemm_common.h +++ b/src/contrib/cblas/gemm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,16 +22,17 @@ * \file tvm/contrib/gemm.h * \brief Shared implementation of gemm */ -#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ -#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ +#pragma once + +#include +#include #include namespace tvm { namespace contrib { using namespace runtime; - -inline int ColumnStride(DLTensor* tensor) { +inline int ColumnStride(DLTensor *tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) { } } - -inline int ElementStride(DLTensor* tensor) { +inline int ElementStride(DLTensor *tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) { } } - // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed(DLTensor* tensor) { +inline bool IsInPlaceTransposed(DLTensor *tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } - -inline int RowCount(DLTensor* tensor, bool trans) { +inline int RowCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } - -inline int ColumnCount(DLTensor* tensor, bool trans) { +inline int ColumnCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. -template +template inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, - transa, - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - static_cast(alpha), - reinterpret_cast(static_cast(B->data) - + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), static_cast(alpha), + reinterpret_cast( + static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast(static_cast(A->data) - + A->byte_offset), - ColumnStride(A), - static_cast(beta), - reinterpret_cast(static_cast(C->data) - + C->byte_offset), + reinterpret_cast( + static_cast(A->data) + A->byte_offset), + ColumnStride(A), static_cast(beta), + reinterpret_cast( + static_cast(C->data) + C->byte_offset), ColumnStride(C)); } +inline int ColumnStride3D(DLTensor *tensor) { + // If the tensor itself is transposed then it will have strides + // backward from what we expect. Regardless, the max of the strides + // (the other stride is 1) is the column stride. + if (tensor->strides) { + return std::max(tensor->strides[1], tensor->strides[2]); + } else { + return tensor->shape[2]; + } +} +inline int ElementStride3D(DLTensor *tensor) { + if (tensor->strides) { + return std::min(tensor->strides[1], tensor->strides[2]); + } else { + return 1; + } +} +// Reversed strides indicates an in-place transpose operation. +inline bool IsInPlaceTransposed3D(DLTensor *tensor) { + return tensor->strides && (tensor->strides[2] > tensor->strides[1]); +} +inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 2 : 1]; +} +inline int ColumnCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 1 : 2]; +} +template +inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { + using DType = typename TBatchGemmOp::TDatatype; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + CHECK_EQ(A->ndim, 3); + CHECK_EQ(B->ndim, 3); + CHECK_EQ(C->ndim, 3); + int batch_size = BatchCount3D(A); + CHECK_EQ(BatchCount3D(B), batch_size); + CHECK_EQ(BatchCount3D(C), batch_size); + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + // C can never be transposed. + CHECK(!IsInPlaceTransposed3D(C)); + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed3D(A) ? !transa : transa; + transb = IsInPlaceTransposed3D(B) ? !transb : transb; + CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); + CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + const int A_size = A->shape[1] * A->shape[2]; + const int B_size = B->shape[1] * B->shape[2]; + const int C_size = C->shape[1] * C->shape[2]; + DType *A_data = reinterpret_cast( + static_cast(A->data) + A->byte_offset); + DType *B_data = reinterpret_cast( + static_cast(B->data) + B->byte_offset); + DType *C_data = reinterpret_cast( + static_cast(C->data) + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast(alpha), + B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + static_cast(beta), C_data, C_size, ColumnStride3D(C)); +} + } // namespace contrib } // namespace tvm - -#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index 6705328ee50a..808c07a2e602 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -16,19 +16,26 @@ # under the License. import tvm import numpy as np +import topi.testing from tvm.contrib import cblas -def test_matmul_add(): - n = 1024 - l = 128 - m = 235 - bias = tvm.var('bias', dtype=tvm.float32) - A = tvm.placeholder((n, l), name='A') - B = tvm.placeholder((l, m), name='B') - C = cblas.matmul(A, B) +def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32): + bias = tvm.var('bias', dtype=dtype) + ashape = (l, n) if transa else (n, l) + bshape = (m, l) if transb else (l, m) + A = tvm.placeholder(ashape, name='A', dtype=dtype) + B = tvm.placeholder(bshape, name='B', dtype=dtype) + C = cblas.matmul(A, B, transa, transb) D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D") s = tvm.create_schedule(D.op) + def get_numpy(a, b, bb, transa, transb): + if transa: + a = a.transpose() + if transb: + b = b.transpose() + return np.dot(a, b) + bb + def verify(target="llvm"): if not tvm.module.enabled(target): print("skip because %s is not enabled..." % target) @@ -38,15 +45,69 @@ def verify(target="llvm"): return ctx = tvm.cpu(0) f = tvm.build(s, [A, B, D, bias], target) - a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx) d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx) bb = 10.0 f(a, b, d, bb) tvm.testing.assert_allclose( - d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5) + d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5) + verify() + +def test_matmul_add(): + verify_matmul_add(235, 128, 1024) + verify_matmul_add(235, 128, 1024, True, False) + verify_matmul_add(235, 128, 1024, False, True) + verify_matmul_add(235, 128, 1024, True, True) + verify_matmul_add(1, 16, 4) + verify_matmul_add(1, 16, 3, True, False) + verify_matmul_add(1, 16, 3, False, False) + verify_matmul_add(1, 16, 3, True, True) + +def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32): + ashape = (batch, l, n) if transa else (batch, n, l) + bshape = (batch, m, l) if transb else (batch, l, m) + A = tvm.placeholder(ashape, name='A', dtype=dtype) + B = tvm.placeholder(bshape, name='B', dtype=dtype) + C = cblas.batch_matmul(A, B, transa, transb) + D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D") + s = tvm.create_schedule(D.op) + + def get_numpy(a, b, transa, transb): + if transa: + a = a.transpose(0, 2, 1) + if not transb: + b = b.transpose(0, 2, 1) + return topi.testing.batch_matmul(a, b) + + def verify(target="llvm"): + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + ctx = tvm.cpu(0) + f = tvm.build(s, [A, B, D], target) + a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx) + f(a, b, d) + tvm.testing.assert_allclose( + d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5) verify() +def test_batch_matmul(): + verify_batch_matmul(16, 235, 128, 1024) + verify_batch_matmul(16, 235, 128, 1024, True, False) + verify_batch_matmul(16, 235, 128, 1024, False, True) + verify_batch_matmul(16, 235, 128, 1024, True, True) + verify_batch_matmul(1, 1, 16, 3) + verify_batch_matmul(1, 1, 16, 3, True, False) + verify_batch_matmul(1, 1, 16, 3, False, False) + verify_batch_matmul(1, 1, 16, 3, True, True) + verify_batch_matmul(1, 1, 16, 3, iterative=True) if __name__ == "__main__": test_matmul_add() + test_batch_matmul() From ced51108a4d3cc6d2ad5d959338fead80dda5dfc Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Tue, 21 May 2019 16:34:35 -0700 Subject: [PATCH 028/176] Add `SkipVectorize` pass (#3222) --- docs/api/python/dev.rst | 1 + include/tvm/build_module.h | 4 ++++ include/tvm/ir_pass.h | 35 +++++++++++++++++++++-------------- python/tvm/build_module.py | 8 ++++++-- src/codegen/build_module.cc | 7 ++++++- src/pass/vectorize_loop.cc | 22 ++++++++++++++++++++-- 6 files changed, 58 insertions(+), 19 deletions(-) diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index e4b207bf4cbc..7bb938ca7517 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -61,6 +61,7 @@ tvm.ir_pass tvm.ir_pass.CanonicalSimplify tvm.ir_pass.StorageFlatten tvm.ir_pass.VectorizeLoop + tvm.ir_pass.SkipVectorize tvm.ir_pass.UnrollLoop tvm.ir_pass.ThreadSync tvm.ir_pass.StorageRewrite diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 208f086f86c0..7fb456c823a7 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -246,6 +246,9 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable select rewriting. */ bool disable_select_rewriting = false; + /*! \brief Whether to disable loop vectorization. */ + bool disable_vectorize = false; + void VisitAttrs(AttrVisitor* v) final { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); @@ -260,6 +263,7 @@ class BuildConfigNode : public Node { v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("disable_select_rewriting", &disable_select_rewriting); + v->Visit("disable_vectorize", &disable_vectorize); } static constexpr const char* _type_key = "BuildConfig"; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ef4dc4ed9d7..e1c92e50e6ad 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt, /*! * \brief vectorize the constant loops - * \param stmt The statment to be vectorized. + * \param stmt The statement to be vectorized. * \return Transformed stmt. */ Stmt VectorizeLoop(Stmt stmt); +/*! + * \brief convert vectorized loops into serialized loops + * \param stmt The statement to skip vectorization on. + * \return Transformed stmt. + */ +Stmt SkipVectorize(Stmt stmt); + /*! * \brief instruments bound checkers. -* \param stmt The statment to be instrumented. -* \return Instrumented Stmt. +* \param stmt The statement to be instrumented. +* \return Instrumented stmt. */ Stmt InstrumentBoundCheckers(Stmt stmt); /*! * \brief Inject virtual thread loops into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectVirtualThread(Stmt stmt); /*! * \brief Inject prefetch instructions into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectPrefetch(Stmt stmt); /*! * \brief Inject double buffer into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param split_loop Loop splitting factor. * \return Transformed stmt. */ @@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); /*! * \brief Inject copy intrinsics with optional pad. * - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param pragma_key The pragma key for hint of copy. * \param fintrin The function with signature * @@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt, * Trying to share space between allocations to make * a static allocation plan when possible. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt StorageRewrite(Stmt stmt); @@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop); /*! * \brief Detect and insert sync points to co-processor. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt CoProcSync(Stmt stmt); @@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt); /*! * \brief Lift common attrs with attr_key to outer scope. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \param attr_key The attribute key to be checked. * \return Transformed stmt. */ @@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); /*! * \brief Detect and rewrite unsafe select that contains memory access. - * \param stmt The statment to be rewritten. + * \param stmt The statement to be rewritten. * \return Transformed stmt. */ Stmt RewriteUnsafeSelect(Stmt stmt); @@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt); * \brief Lower attached storage access information. * Do this pass after all storage access analysis finish. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt LowerStorageAccessInfo(Stmt stmt); @@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt); * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt DecorateDeviceScope(Stmt stmt); @@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt); * \return a LoweredFunc with the specified signiture. * * \note - * The function signiture have two cases + * The function signature have two cases * * let num_packed_args = len(api_args) - num_unpacked_args; * diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 120bf629a959..a28ab98fb60e 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -143,7 +143,8 @@ class BuildConfig(NodeBase): "double_buffer_split_loop": 1, "dump_pass_ir": False, "instrument_bound_checkers": False, - "disable_select_rewriting": False + "disable_select_rewriting": False, + "disable_vectorize": False } _dump_ir = DumpIR() @@ -384,7 +385,10 @@ def lower(sch, # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - stmt = ir_pass.VectorizeLoop(stmt) + if cfg.disable_vectorize: + stmt = ir_pass.SkipVectorize(stmt) + else: + stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.StorageRewrite(stmt) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 9b30ced90c4f..ac6b797d9683 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch, if (loop_partition) { stmt = ir::LoopPartition(stmt, config->partition_const_loop); } - stmt = ir::VectorizeLoop(stmt); + if (config->disable_vectorize) { + stmt = ir::SkipVectorize(stmt); + } else { + stmt = ir::VectorizeLoop(stmt); + } stmt = ir::InjectVirtualThread(stmt); stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = ir::StorageRewrite(stmt); @@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; + p->stream << "disable_vectorize=" << op->disable_vectorize; p->stream << ")"; }); diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index f87e80c2d030..8c3d383c1529 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer().Mutate(stmt); } +class VectorizeSkipper : public IRMutator { + public: + Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op->for_type == ForType::Vectorized) { + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, + op->body); + } else { + return stmt; + } + } +}; + +Stmt SkipVectorize(Stmt stmt) { + return VectorizeSkipper().Mutate(stmt); +} + } // namespace ir } // namespace tvm From f279c101753e975cdeab2dabe68377b9624cec9a Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Wed, 22 May 2019 10:44:47 +0800 Subject: [PATCH 029/176] [TFLite] Convert TFLite NCHW to NHWC (#3141) * Convert TFLite NCHW to NHWC * Minor comment fix --- python/tvm/relay/frontend/tflite.py | 120 +++--------------- tests/python/frontend/tflite/test_forward.py | 123 ++++--------------- tutorials/frontend/from_tflite.py | 19 +-- 3 files changed, 41 insertions(+), 221 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ff62d89412e9..bfd63bb0140e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -209,44 +209,10 @@ def convert_reshape(self, op): reshape_options = ReshapeOptions() reshape_options.Init(op_options.Bytes, op_options.Pos) target_shape = reshape_options.NewShapeAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) - - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Reshape is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.reshape(in_expr, newshape=tuple(target_shape)) - # The rule is channel first. - # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if len(target_shape) == 1 or len(target_shape) == 2: - pass - elif len(target_shape) == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif len(target_shape) == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - raise tvm.error.OpAttributeInvalid( - 'Length of target shape must be between 1 and 5 for operator Reshape.') - return out def convert_softmax(self, op): @@ -269,7 +235,7 @@ def convert_softmax(self, op): return out def convert_concatenation(self, op): - """ convert TFLite concatenation""" + """Convert TFLite concatenation""" try: from tflite.Operator import Operator from tflite.ConcatenationOptions import ConcatenationOptions @@ -292,15 +258,6 @@ def convert_concatenation(self, op): concatenation_options.Init(op_options.Bytes, op_options.Pos) concatenation_axis = concatenation_options.Axis() fused_activation_fn = concatenation_options.FusedActivationFunction() - input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy()) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length <= 4: - axis_convert_map = [0] + list(range(2, input_shape_length)) + [1] - concatenation_axis = axis_convert_map[concatenation_axis] - else: - raise NotImplementedError("Not support input shape length {} of concatenatio : " - .format(str(input_shape_length))) # with axis in N H W C out = _op.concatenate(in_exprs, axis=concatenation_axis) @@ -336,20 +293,6 @@ def convert_add(self, op): rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - # In this case, we have to be careful about formatting. - input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy()) - if input_shape_length in (1, 2): - pass - elif input_shape_length == 3: - # N H*W C to N C H*W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # N H W C to N C H W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2)) - else: - msg = 'Input shape length {} for operator ADD is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.add(lhs_expr, rhs_expr) return out @@ -440,46 +383,10 @@ def convert_squeeze(self, op): squeeze_options = SqueezeOptions() squeeze_options.Init(op_options.Bytes, op_options.Pos) squeeze_axis = squeeze_options.SqueezeDimsAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) - output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Squeeze is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) - # The rule is channel first. - # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if output_shape_length in (1, 2): - pass - elif output_shape_length == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif output_shape_length == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - msg = 'Output shape length {} for operator Squeeze is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length)) - return out def convert_fused_activation_function(self, in_expr, fused_activation_fn): @@ -562,13 +469,16 @@ def convert_conv(self, op, conv_type): params = {'kernel_size': [kernel_h, kernel_w], 'strides': [stride_h, stride_w], 'dilation': [dilation_h, dilation_w], - 'padding': [0, 0]} + 'padding': [0, 0], + 'data_layout': 'NHWC'} if is_depthwise_conv: params['channels'] = int(in_channels * multiplier) params['groups'] = int(in_channels) + params['kernel_layout'] = 'HWOI' else: params['channels'] = int(output_channels) + params['kernel_layout'] = 'HWIO' # weight tensor type should be UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() @@ -578,12 +488,9 @@ def convert_conv(self, op, conv_type): in_expr = self.get_expr(input_tensor_idx) weight_value = self.get_tensor_value(weight_tensor) - if is_depthwise_conv: - # TFLite is M KH KW IC, we require IC M KH KW - weight_value = weight_value.transpose((3, 0, 1, 2)) - else: - # TFLite is OC KH KW IC, we require OC IC KH kW - weight_value = weight_value.transpose((0, 3, 1, 2)) + # TFLite is OC/M KH KW IC, we require KH KW IC OC/M + # M means multiplier in depthwise convolution + weight_value = weight_value.transpose((1, 2, 3, 0)) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) @@ -592,9 +499,10 @@ def convert_conv(self, op, conv_type): elif padding == Padding.SAME: pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) - in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0), + in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (pad_top, pad_bottom), - (pad_left, pad_right))) + (pad_left, pad_right), + (0, 0))) else: raise tvm.error.OpAttributeUnimplemented( 'Padding format {} is not supported for operator Conv.'.format(padding)) @@ -610,7 +518,8 @@ def convert_conv(self, op, conv_type): bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str) - out = _op.nn.bias_add(out, bias_expr) + channel_axis = 3 + out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) # If we have fused activations if fused_activation_fn != ActivationFunctionType.NONE: @@ -648,7 +557,8 @@ def convert_pool2d(self, op, pool_type): params = {'pool_size': (filter_h, filter_w), 'strides': (stride_h, stride_w), - 'padding': [0, 0]} + 'padding': [0, 0], + 'layout': 'NHWC'} in_expr = self.get_expr(input_tensor_idx) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 63a345a5a6d5..8fc2d550d556 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -116,12 +116,10 @@ def run_tflite_graph(tflite_model_buf, input_data): return tflite_output -def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, - output_tensors, output_need_transpose=False, - init_global_variables=False): +def compare_tflite_with_tvm(in_data, in_name, input_tensors, + output_tensors, init_global_variables=False): """Generic function to generate and compare TFLite and TVM output""" - tflite_in_data = convert_to_list(tflite_in_data) - tvm_in_data = convert_to_list(tvm_in_data) + in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) in_node = [0] * len(in_name) for i in range(len(in_name)): @@ -134,7 +132,7 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, converter = tf.contrib.lite.TFLiteConverter.from_session( sess, input_tensors, output_tensors) tflite_model_buffer = converter.convert() - tflite_output = run_tflite_graph(tflite_model_buffer, tflite_in_data) + tflite_output = run_tflite_graph(tflite_model_buffer, in_data) for device in ["llvm"]: ctx = tvm.context(device, 0) @@ -142,25 +140,9 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, print("Skip because %s is not enabled" % device) continue - tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device) + tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device) for i in range(len(tflite_output)): - if output_need_transpose: - dim = len(tvm_output[i].shape) - if dim == 3: - # N C H*W to N H*W C - axes = (0, 2, 1) - elif dim == 4: - # N C H W to N H W C - axes = (0, 2, 3, 1) - else: - raise NotImplementedError("Not support input shape {} of transpose : ". - format(str(dim))) - tvm.testing.assert_allclose(tflite_output[i], - np.transpose(tvm_output[i], axes=axes), - atol=1e-5, rtol=1e-5) - else: - tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], - atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) sess.close() @@ -173,14 +155,12 @@ def _test_pooling_iteration(input_shape, **kwargs): x = -np.arange( np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 - tvm_data = np.transpose(x, axes=(0, 3, 1, 2)) with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=input_shape, dtype='float32') out = nn_ops.pool(in_data, **kwargs) - compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out], - output_need_transpose=True) + compare_tflite_with_tvm(x,'Placeholder:0', [in_data], [out]) def _test_pooling(input_shape, **kwargs): @@ -258,13 +238,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, strides=strides, padding=padding, data_format=data_format) - # TFLite is NHWC, TVM is NCHW - tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') - tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2)) - # TFLite output is NHWC, TVM is NCHW, we need transpose - compare_tflite_with_tvm(tflite_data_array, tvm_data_array, - 'Placeholder:0', [in_data], [out], - output_need_transpose=True) + data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') + compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out]) def test_forward_convolution(): @@ -286,22 +261,11 @@ def test_forward_convolution(): def _test_reshape(data, out_shape): """ One iteration of reshape operation with given data and out shape """ - # see relay/frontend/tflite.py convert_reshape more detail of channel first rule - if len(data.shape) == 1 or len(data.shape) == 2: - tvm_data = data - elif len(data.shape) == 3: - tvm_data = np.transpose(data, axes=(0, 2, 1)) - elif len(data.shape) == 4: - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data)))) - with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) out = array_ops.reshape(in_data, out_shape) - compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_reshape(): @@ -319,18 +283,6 @@ def _test_concatenation(data, axis): """ One iteration of concatenation """ assert len(data) >= 1 - need_transpose = False - if len(data[0].shape) == 1 or len(data[0].shape) == 2: - tvm_data = data - elif len(data[0].shape) == 3: - #need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] - elif len(data[0].shape) == 4: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data)))) with tf.Graph().as_default(): in_data = [ @@ -339,7 +291,7 @@ def _test_concatenation(data, axis): out = array_ops.concat(in_data, axis=axis) name = ["in_{}:0".format(idx) for idx in range(len(data))] - compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose) + compare_tflite_with_tvm(data, name, in_data, [out]) def test_forward_concatenation(): @@ -366,33 +318,19 @@ def _test_add(data): """ One iteration of add """ assert len(data) == 2 - need_transpose = False - if len(data[0].shape) == 1 or len(data[0].shape) == 2: - tvm_data = data - elif len(data[0].shape) == 3: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] - elif len(data[0].shape) == 4: - need_transpose = True - tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] - else: - raise NotImplementedError("Not support input shape {} of add : ". - format(str(len(data.shape)))) # Test with two tensors with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] out = math_ops.add(in_data[0], in_data[1]) - compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], - in_data, [out], need_transpose) + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) # Test with tensor and constant with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) - compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'], - in_data, [out], need_transpose) + compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) def test_forward_add(): @@ -415,19 +353,6 @@ def _test_squeeze(data, squeeze_dims=None): if squeeze_dims is None: squeeze_dims = [] - # see relay/frontend/tflite.py convert_squeeze more detail of channel first rule - if len(data.shape) == 1 or len(data.shape) == 2: - tvm_data = data - elif len(data.shape) == 3: - tvm_data = np.transpose(data, axes=(0, 2, 1)) - elif len(data.shape) == 4: - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - else: - raise NotImplementedError("Not support input shape {} of reshape : ". - format(str(len(data.shape)))) - - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) - with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) @@ -436,7 +361,7 @@ def _test_squeeze(data, squeeze_dims=None): else: out = array_ops.squeeze(in_data) - compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_squeeze(): @@ -453,7 +378,7 @@ def _test_softmax(data): with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) out = nn_ops.softmax(in_data) - compare_tflite_with_tvm(data, data, 'Placeholder:0', [in_data], [out]) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_softmax(): """ Softmax """ @@ -496,10 +421,8 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None): in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32') out = nn_ops.bias_add(out, in_bias) - tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') - tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2)) - compare_tflite_with_tvm(tflite_data_array, tvm_data_array, - 'Placeholder:0', [in_data], [out]) + data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') + compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out]) def test_forward_fully_connected(): @@ -523,9 +446,8 @@ def test_forward_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -538,9 +460,8 @@ def test_forward_mobilenet_v2(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -557,9 +478,8 @@ def test_forward_inception_v3_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) @@ -572,9 +492,8 @@ def test_forward_inception_v4_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') - tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index 67edeb8a38de..f8686e9d20ab 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -117,32 +117,23 @@ def extract(path): plt.show() image_data = np.asarray(resized_image).astype("float32") -# convert HWC to CHW -image_data = image_data.transpose((2, 0, 1)) - -# after expand_dims, we have format NCHW +# after expand_dims, we have format NHWC image_data = np.expand_dims(image_data, axis=0) # preprocess image as described here: # https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243 -image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1 -image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1 -image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1 +image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1 +image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1 +image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1 print('input', image_data.shape) -#################################################################### -# -# .. note:: Input layout: -# -# Currently, TVM TFLite frontend accepts ``NCHW`` as input layout. - ###################################################################### # Compile the model with relay # --------------------------------------------- # TFLite input tensor name, shape and type input_tensor = "input" -input_shape = (1, 3, 224, 224) +input_shape = (1, 224, 224, 3) input_dtype = "float32" # parse TFLite model and convert into Relay computation graph From ea40f53fdd2eea9af105d305697272e3af3392b5 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 21 May 2019 22:28:49 -0700 Subject: [PATCH 030/176] [Team] Eddie -> PMC (#3220) --- CONTRIBUTORS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 7e0ad806c7ec..5b5c6b745efb 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -55,7 +55,7 @@ We do encourage everyone to work anything they are interested in. - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Leyuan Wang](https://github.com/Laurawly): @Laurawly: - topi - [Yao Wang](https://github.com/kevinthesun): @kevinthesun: - topi, vision -- [Eddie Yan](https://github.com/eqy): @eqy - runtime, autotvm, rpc, topi +- [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, topi, relay ## Reviewers From 3602f8cf12581c77a2f31cfc33174715b2871713 Mon Sep 17 00:00:00 2001 From: llyfacebook <34827865+llyfacebook@users.noreply.github.com> Date: Wed, 22 May 2019 09:11:17 -0700 Subject: [PATCH 031/176] Add packing for int8 1x1 convolution and support the int8 group convolution on X86 (#2991) * Support the 1x1 int8 conv with NHWC layout and weight packing fix linter * fix the memoize issue * fix the failed nhwc test * add the schedule for pack to unbreak other tests * skip avx512 compile * Support the 1x1 int8 conv with NHWC layout and weight packing fix linter * fix the memoize issue * fix the failed nhwc test * add the schedule for pack to unbreak other tests * skip avx512 compile * Unify the data_layout and kernel_layout relation * add asf header * fix the comment * retrigger the build/test --- topi/python/topi/generic/nn.py | 18 +++ topi/python/topi/nn/conv2d.py | 25 +++- topi/python/topi/x86/conv2d.py | 101 ++++++++++++++-- topi/python/topi/x86/conv2d_avx_1x1.py | 106 ++++++++++++++++- .../python/test_topi_conv2d_nhwc_pack_int8.py | 90 ++++++++++++++ .../test_topi_group_conv2d_NCHWc_int8.py | 111 ++++++++++++++++++ 6 files changed, 436 insertions(+), 15 deletions(-) create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py create mode 100644 topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index db1c772279e5..60a2d55486e5 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -52,6 +52,24 @@ def schedule_conv2d_nchw(outs): return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_conv2d_nhwc_pack(outs): + """Schedule for conv2d_nhwc_pack + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv2d_nhwc_pack + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.generic_func def schedule_conv2d_nhwc(outs): """Schedule for conv2d_nhwc diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 06d4074147c1..83e0274597d7 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -28,8 +28,8 @@ # workload description of conv2d Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', + 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @tvm.target.generic_func def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None): @@ -95,11 +95,24 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): return None -def _get_workload(data, kernel, stride, padding, out_dtype): +def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] + if data_layout == 'NCHW': + _, CI, IH, IW = [x.value for x in data.shape] + elif data_layout == 'NHWC': + _, IH, IW, CI = [x.value for x in data.shape] + elif data_layout == 'HWCN': + IH, IW, CI, _ = [x.value for x in data.shape] + else: + raise ValueError("not support this layout {} yet".format(data_layout)) + + if data_layout == 'NCHW': + CO, CIG, KH, KW = [x.value for x in kernel.shape] + else: + KH, KW, CO, CIG = [x.value for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride else: @@ -107,7 +120,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype): assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ "Do not support inputs with different data types now. ' \ '{} vs. {}".format(data.dtype, kernel.dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 4d4b3fef4826..c333892a9918 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -37,7 +37,8 @@ logger = logging.getLogger('topi') -def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, + layout='NCHW'): """ Get default schedule config for the workload """ @@ -46,7 +47,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth from .depthwise_conv2d import _fallback_schedule _fallback_schedule(cfg, wkl) else: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: conv2d_avx_1x1._fallback_schedule(cfg, wkl) @@ -62,6 +63,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if layout == 'NCHW': n, ic, h, w = dshape oc, _, kh, kw = kshape + elif layout == 'NHWC': + n, h, w, ic = dshape + kh, kw, oc, _ = kshape elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape if data.dtype == 'uint8': @@ -93,21 +97,31 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): cfg.define_knob("unroll_kw", [True, False]) -@autotvm.register_topi_compute(conv2d, 'cpu', 'direct') +@autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) if cfg.is_fallback: _get_default_config(cfg, data, kernel, strides, padding, out_dtype) return _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + + # HWOI kernel layout is for NHWC and HWCN + kh, kw, _, _ = get_const_tuple(kernel.shape) if layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) - if layout == 'NHWC': + elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8": + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout) + # specialize for INT8 1X1 conv on X86 + return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides, + padding, dilation, out_dtype) + elif layout == 'NHWC': return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) @@ -226,6 +240,58 @@ def traverse(op): return s +@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct']) +def schedule_conv2d_nhwc_pack(cfg, outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 4: # schedule bias + bn + relu + n, h, w, c = op.axis + fused = s[op].fuse(n, h, w) + s[op].parallel(fused) + s[op].vectorize(c) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv2d_nhwc_pack_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + if data.dtype == 'uint8': + # int8 conv kernel is 7-dim + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) + else: + raise ValueError("Only support 1x1 kernel with " + "schedule_conv2d_nhwc_pack.") + else: + raise ValueError("Not support this data type {} with " + "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype)) + + scheduled_ops.append(op) + traverse(output_op) + return s + + @generic.schedule_conv2d_nhwc.register("cpu") def schedule_conv2d_nhwc(outs): """Create schedule for tensors""" @@ -427,10 +493,13 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ + get_const_tuple(kernel.shape) else: - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ + get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), @@ -454,7 +523,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, kh = tvm.reduce_axis((0, kernel_height), name='kh') kw = tvm.reduce_axis((0, kernel_width), name='kw') - if data.dtype == 'uint8': + if data.dtype == 'uint8' and groups == 1: assert out_dtype == "int32", \ "INT8 convolution requires input dtype = uint8 and output dtype=int32" # Intel performs dot product of 2 "4" Int8 values @@ -473,6 +542,24 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + if data.dtype == 'uint8': + # for int8 group conv support + n_elems = 4 + ic_chunk = in_channel//ic_bn + ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: + tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\ + (ic_chunk//groups)+ic_outer, + oh*HSTR+kh, ow*WSTR+kw, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * + kernel[occ, ic_outer, kh, kw, ic_f_inner, + oc_block, ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index bcd2cefc2bdf..4994d4580ab5 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -20,8 +20,9 @@ import tvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity -from ..nn.util import infer_pad -from ..util import get_const_tuple +from ..nn.pad import pad +from ..nn.util import infer_pad, get_pad_tuple +from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake from .util import get_fp32_len @@ -251,3 +252,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): s[O].parallel(parallel_axis) return s + + +def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): + # more assertion for the shapes + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + kernel_h, kernel_w, num_filter, channel = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # todo: padding filter to accomodate the intrinsic + + # packing the Filter to let memory access be consecutive for AVX512 intrinsic + # Done in pre-compute stage + packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) + PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], + name="packed_filter") + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: tvm.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") + return Output + + +def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): + """ + Defines the schedule for the int8 nhwc layout. For 1x1 conv, it + is a matrix-multiply operation by using nhwc layout. We will do + packing of weight to make the address access be friendly to int8 + intrinsic + """ + target = tvm.target.current_target(allow_none=False) + int32_lanes = -1 + if check_skylake(target): + int32_lanes = 16 + else: + return s + assert int32_lanes != -1 + + # assertion to fail the unhandled case + _, _, _, ic_num = get_const_tuple(data.shape) + _, _, _, oc_num = get_const_tuple(conv_out.shape) + assert ic_num % 4 == 0 + assert oc_num % 16 == 0 + + ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ih, iw, ic = s[A].op.axis + d_ic_chunk, d_ic_block = s[A].split(ic, factor=4) + s[A].vectorize(d_ic_block) + + C, O = conv_out, last + + batch, oh, ow, oc = s[C].op.axis + kh, kw, ic = s[C].op.reduce_axis + # match the x86 intrinsic + ic_outer, ic_inner = s[C].split(ic, factor=4) + oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes) + + ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor) + s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner) + + pc = dot_16x1x16_int8_int8_int32() + s[C].tensorize(oc_inner, pc) + + if C != O: + batch, last_oh, last_ow, last_oc = s[O].op.axis + oc_chunk, oc_block = s[O].split(ochannel, 16) + # not saw perf improvement to split oh/ow here + s[O].vectorize(oc_block) + + return s diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py new file mode 100644 index 000000000000..763150ac425f --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + adtype = A.dtype + wdtype = W.dtype + + @memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(adtype) + w_np = np.random.uniform(size=w_shape).astype(wdtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc_pack([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + # for device in ['llvm -mcpu=skylake-avx512']: + for device in ['llvm']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_nhwc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0) + + +if __name__ == "__main__": + test_conv2d_nhwc() diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py new file mode 100644 index 000000000000..6ed1b4aabc16 --- /dev/null +++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test for NCHW[x]c convolution""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]i[x]o + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6)) + return kernel + +def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"): + assert dilation == 1, "conv2d_NCHWc does not support dilation for now." + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding)) + + in_height = in_width = in_size + + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. + oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 8 + autotvm.DispatchContext.current.silent = True + A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8') + W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8') + + @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8") + w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8") + c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups) + return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \ + _transform_data(c_np, oc_block) + + a_np, w_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + (dilation, dilation), + layout='NCHW%dc'%ic_block, + out_layout="NCHW%dc"%oc_block, + out_dtype=dtype) + s = topi.generic.schedule_conv2d_NCHWc([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + # print(tvm.lower(s, [A, W, C], simple_mode=True)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) + + # for device in ["llvm -mcpu=skylake-avx512"]: + for device in ["llvm"]: + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + +def test_conv2d_NCHWc(): + # ResNet50 workloads + verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3) + +if __name__ == "__main__": + test_conv2d_NCHWc() From b950767a121c2c222b79edc830dd7dcd662d8f7c Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Wed, 22 May 2019 11:09:01 -0700 Subject: [PATCH 032/176] [Bugfix] Fix sort changing original input data issue (#3212) * sort bugfix for not rearranging input data * separate sort schedule * fix lint * use identity op instead * fix lint * remove redundent code --- src/op/extern_op.cc | 5 ++- topi/python/topi/cuda/nms.py | 3 +- topi/python/topi/cuda/sort.py | 47 +++++++++++++++++--- topi/python/topi/cuda/vision.py | 77 +++------------------------------ 4 files changed, 53 insertions(+), 79 deletions(-) diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index e6c6039b610e..7023aebe17ad 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name, CHECK_EQ(inputs.size(), input_placeholders.size()); for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); - CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape)); + CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); + for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { + CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + } CHECK_EQ(input_placeholders[i]->strides.size(), 0U); } n->inputs = std::move(inputs); diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 0c27bd216999..925cf24acd11 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -24,6 +24,7 @@ from tvm.intrin import if_then_else, log, power from topi.vision import non_max_suppression, get_valid_counts from .sort import argsort +from .. import tag def get_valid_counts_pre(data, flag, idx, score_threshold): @@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, "valid_count_buf", data_alignment=4) score_axis = score_index score_shape = (batch_size, num_anchors) - score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) + score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype, diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 99ba8527cdfb..678d494dae50 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -20,6 +20,10 @@ from tvm import api from topi.sort import argsort +from topi.math import identity +from .. import generic +from .. import tag + def sort_ir(data, output, axis, is_ascend): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. @@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend): return ib.get() - - def sort_nms_ir(data, valid_count, output, axis, is_ascend): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. @@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 out : tvm.Tensor The output of this function. """ - data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) + sorted_data = identity(data) if flag: valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) out = tvm.extern([data.shape], - [data, valid_count], + [sorted_data, valid_count], lambda ins, outs: sort_nms_ir( ins[0], ins[1], outs[0], axis, is_ascend), dtype="int32", - in_buffers=[data_buf, valid_count_buf], + in_buffers=[sorted_data_buf, valid_count_buf], out_buffers=[out_buf], name="argsort_nms_gpu", tag="argsort_nms_gpu") else: out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) out = tvm.extern([data.shape], - [data], + [sorted_data], lambda ins, outs: sort_ir( ins[0], outs[0], axis, is_ascend), dtype=dtype, - in_buffers=[data_buf], + in_buffers=[sorted_data_buf], out_buffers=[out_buf], name="argsort_gpu", tag="argsort_gpu") return out + +@generic.schedule_argsort.register(["cuda", "gpu"]) +def schedule_argsort(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argsort + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + from .injective import _schedule_injective + def traverse(op): + if tag.is_broadcast(op.tag): + _schedule_injective(op, s) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + traverse(outs[0].op) + + return s diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 78f5c1f51ec6..968e554ac81d 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -25,41 +25,17 @@ def _default_schedule(outs): """Default schedule for gpu.""" - target = tvm.target.current_target() outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] - + from .injective import _schedule_injective def traverse(op): - """inline all one-to-one-mapping operators except the last stage (output)""" - if op.tag in ["nms", "invalid_to_bottom"]: - if op.tag == "nms": - sort = op.input_tensors[1] - else: - out = op.input_tensors[0] - sort = s[out].op.input_tensors[1] - score = s[sort].op.input_tensors[0] - fused = s[score].fuse(*s[score].op.axis) - num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads) - bx, tx = s[score].split(fused, factor=num_thread) - s[score].bind(bx, tvm.thread_axis("blockIdx.x")) - s[score].bind(tx, tvm.thread_axis("threadIdx.x")) - if tag.is_broadcast(op.tag): - if op not in s.outputs: - s[op].compute_inline() - else: - x = op.output(0) - fused = s[x].fuse(*s[x].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads - bx, tx = s[x].split(fused, factor=num_thread) - s[x].bind(bx, tvm.thread_axis("blockIdx.x")) - s[x].bind(tx, tvm.thread_axis("threadIdx.x")) - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - + if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']: + _schedule_injective(op, s) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) scheduled_ops.append(op) - traverse(outs[0].op) return s @@ -173,19 +149,7 @@ def schedule_proposal(outs): s: Schedule The computation schedule for the op. """ - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - from .injective import _schedule_injective - def traverse(op): - if op.tag in ['bbox_score', 'sorted_bbox']: - _schedule_injective(op, s) - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - scheduled_ops.append(op) - traverse(outs[0].op) - return s + return _default_schedule(outs) @generic.schedule_get_valid_counts.register(["cuda", "gpu"]) def schedule_get_valid_counts(outs): @@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs): The computation schedule for the op. """ return _default_schedule(outs) - -@generic.schedule_argsort.register(["cuda", "gpu"]) -def schedule_argsort(outs): - """Schedule for argsort operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of argsort - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for the op. - """ - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - from .injective import _schedule_injective - def traverse(op): - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - scheduled_ops.append(op) - traverse(outs[0].op) - return s From 6d1520a2556402947e168844c80e6cad7038a90a Mon Sep 17 00:00:00 2001 From: Yuta Hinokuma Date: Thu, 23 May 2019 03:48:39 +0900 Subject: [PATCH 033/176] [WIP] [Relay] [NNVM] [Frontend] implement MaxPool-8 and MaxPool-10 (#3114) --- nnvm/python/nnvm/frontend/onnx.py | 37 ++++++++++++++++++++++++++ python/tvm/relay/frontend/onnx.py | 43 +++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index 2434fb01c1d5..c8b050ad2343 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -27,6 +27,13 @@ __all__ = ['from_onnx'] +def onnx_storage_order2layout(storage_order): + if storage_order not in (0, 1): + raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1') + + return 'NCHW' if sotrage_order == 0 else 'NHWC' + + class OnnxOpConverter(object): """ A helper class for holding onnx op converters. """ @@ -207,8 +214,38 @@ def _impl_v1(cls, inputs, attr, params): class MaxPool(Pool): + """ Operator converter for MaxPool + """ name = 'max_pool' + @classmethod + def _impl_v8(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + # TODO(higumachan): make sure ceil_mode in onnx, and layout? + extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, attr, params) + + @classmethod + def _impl_v10(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + 'ceil_mode': 'ceil_mode' + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + custom_check=dimension_constraint())(inputs, attr, params) class Mul(Elemwise): name = 'mul' diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c70f5aba39fe..18253e498560 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -52,6 +52,15 @@ def revert_caffe2_pad(pads): 'Number of pads must be either 2 or 4.') return pads + +def onnx_storage_order2layout(storage_order): + """converter of onnx storage order parameter to tvm storage order format""" + if storage_order not in (0, 1): + raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1') + + return 'NCHW' if sotrage_order == 0 else 'NHWC' + + def dimension_constraint(): def _dim_check(attrs): if len(attrs['kernel_shape']) == 2: @@ -60,6 +69,7 @@ def _dim_check(attrs): return _dim_check, "Only 2d kernel supported." + class OnnxOpConverter(object): """ A helper class for holding onnx op converters. """ @@ -108,6 +118,7 @@ def _impl_v1(cls, inputs, attr, params): inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2) return get_relay_op(op_name)(*inputs) + class Pool(OnnxOpConverter): """ A helper class for pool op converters. """ @@ -247,6 +258,7 @@ def _impl_v1(cls, inputs, attr, params): inputs[1], units=channels) return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) + class MatMul(OnnxOpConverter): """ Operator converter for MatMul. """ @@ -257,9 +269,40 @@ def _impl_v1(cls, inputs, attr, params): input_1_t = _op.transpose(inputs[1], axes=(1, 0)) return _op.nn.dense(inputs[0], input_1_t) + class MaxPool(Pool): + """ Operator converter for MaxPool + """ name = 'max_pool' + @classmethod + def _impl_v8(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + # TODO(higumachan): make sure ceil_mode in onnx, and layout? + extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, attr, params) + + @classmethod + def _impl_v10(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + 'ceil_mode': 'ceil_mode' + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + custom_check=dimension_constraint())(inputs, attr, params) class Mul(Elemwise): name = 'multiply' From d7f2d30f2a1b200753f72b56ce514db2a019df23 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Wed, 22 May 2019 13:52:52 -0700 Subject: [PATCH 034/176] [relay][pass manager] Open transform namespace (#3226) --- include/tvm/relay/pass.h | 204 +-------------- include/tvm/relay/transform.h | 243 ++++++++++++++++++ python/tvm/relay/__init__.py | 18 +- python/tvm/relay/_ir_pass.pyi | 54 ---- python/tvm/relay/_transform.py | 21 ++ python/tvm/relay/ir_pass.py | 312 +---------------------- python/tvm/relay/transform.py | 325 ++++++++++++++++++++++++ python/tvm/relay/transform.pyi | 71 ++++++ src/relay/pass/pass_manager.cc | 81 +++--- tests/python/relay/test_pass_manager.py | 45 ++-- 10 files changed, 728 insertions(+), 646 deletions(-) create mode 100644 include/tvm/relay/transform.h create mode 100644 python/tvm/relay/_transform.py create mode 100644 python/tvm/relay/transform.py create mode 100644 python/tvm/relay/transform.pyi diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 31067925fa63..c84e3f952de4 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -20,46 +20,12 @@ /*! * \file tvm/relay/pass.h * \brief The set of Relay passes written in C++. - * - * This file also implements a pass manager. The pass manager manages a sequence - * of Relay-to-Relay transformation passes over a particlar unit of AST. The - * design is largely inspired from LLVM's pass manager and modern deep learning - * frameworks that perform tensor->tensor transformations. - * - * The responsibilities of a traditional compiler pass manager usually involves: - * - Organizing the execution order of optimization passes though not - * necessarily in the optimal sequence. - * - Collecting required analysis information and keep them up-to-date. - * - Reducing the effort required to implement new passes for compiler - * developers, etc. - * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work - * different granularity, i.e. module level, function level, and even sequential - * passe that contains a host of passes. - * - * However, we also extend the functionality of the traditional pass manager - * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass - * manager performs the Relay.Module -> Relay.Module transformation. All - * different types of passes, including the sequential-level pass object, are - * essentially pass objects. This design, therefore, effectively provides users - * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with - * the pass manager, external users will be able to have custom passes correctly - * scheduled without having to modify a single handcrafted pass order. - * - * In the future we need to describe constraints between passes. For example, - * we may want to preserve dependencies between different passes and validate - * them on the completion of a certain pass. - * - * We also need to store side information and import the error reporting system. - */ + */ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ #include #include -#include #include #include #include @@ -72,174 +38,6 @@ namespace tvm { namespace relay { -namespace pass { - -/* - * \brief The context of pass. - */ -class PassContext; - -/*! - * \brief PassContextNode contains the information that a pass can rely on, such as - * analysis results. - */ -class PassContextNode : public RelayNode { - public: - /*! - * \brief The error reporter used to notify users why an optimization fails. - */ - ErrorReporter err_reporter; - - PassContextNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - } - - TVM_DLL static PassContext make(); - - static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) - -/* - * \brief The meta data of a pass. - * - * PassInfo can be extended conveniently in the future if more meta information - * is needed. - */ -class PassInfo; - -/*! - * \brief PassInfoNode contains meta data that will be used to help optimization - * and analysis. - */ -class PassInfoNode : public RelayNode { - public: - /*! \brief The minimal optimization level that this pass will be enabled. */ - int opt_level; - - /*! \brief The name of an optimization/analysis pass. */ - std::string name; - - /*! \brief The passes that are required to perform the current pass. */ - tvm::Array required; - - PassInfoNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("opt_level", &opt_level); - v->Visit("name", &name); - v->Visit("required", &required); - } - - TVM_DLL static PassInfo make(int opt_level, std::string name, - tvm::Array required); - - static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) - -class Pass; - -/*! - * \brief PassNode is the base type of differnt types of optimization passes. - * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. - */ -class PassNode : public RelayNode { - public: - /* - * \brief Get the pass information/meta data. */ - virtual PassInfo Info() const = 0; - - /*! - * \brief Set the context information for a pass. - * - * \param pass_ctx The context information for a certain pass. - */ - virtual void SetContext(const PassContext& pass_ctx) = 0; - - /*! - * \brief Execute the optimization pass using a functor. - * - * \param mod The module that an optimization pass runs on. - * - * \return The updated module. - */ - virtual Module operator()(const Module& mod) const = 0; - - void VisitAttrs(tvm::AttrVisitor* v) override {} - - static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); -}; - -class Pass : public NodeRef { - public: - Pass() = default; - explicit Pass(NodePtr p) : NodeRef(p) {} - - PassNode* operator->() const { - return static_cast(this->node_.get()); - } - - using ContainerType = PassNode; -}; - -/* - * \brief Create a module pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the module pass. - * \param name The name of the module pass. - * \param required The list of the passes that the module pass is dependent on. - * - * \return The created module pass. - */ -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); - -/* - * \brief Create a function pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the function pass. - * \param name The name of the function pass. - * \param required The list of the passes that the function pass is dependent on. - * - * \return The created function pass. - */ -Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); -/* - * \brief Create a sequential pass. - * - * \param passes The optimization passes will be performed. - * \param opt_level The optimization level of the sequential pass. - * \param name The name of the sequential pass. - * \param required The list of the passes that the sequential pass is dependent on. - * \param disabled The disabled passes. - * - * \return The created sequential pass. - */ -Pass CreateSequentialPass(const tvm::Array& passes, - int opt_level, - const std::string& name, - const tvm::Array& required, - const tvm::Array& disabled); - -} // namespace pass - /*! * \brief Infer the type of an expression. * diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h new file mode 100644 index 000000000000..ba25483dfbb2 --- /dev/null +++ b/include/tvm/relay/transform.h @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/transform.h + * + * This file implements a pass manager. The pass manager manages a sequence + * of Relay-to-Relay transformation passes over a particlar unit of AST. The + * design is largely inspired from LLVM's pass manager and modern deep learning + * frameworks that perform tensor->tensor transformations. + * + * The responsibilities of a traditional compiler pass manager usually involves: + * - Organizing the execution order of optimization passes though not + * necessarily in the optimal sequence. + * - Collecting required analysis information and keep them up-to-date. + * - Reducing the effort required to implement new passes for compiler + * developers, etc. + * + * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * different granularity, i.e. module level, function level, and even sequential + * passe that contains a host of passes. + * + * However, we also extend the functionality of the traditional pass manager + * with the consideration of requirements/convention from deep learning + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * manager performs the Relay.Module -> Relay.Module transformation. All + * different types of passes, including the sequential-level pass object, are + * essentially pass objects. This design, therefore, effectively provides users + * a consistent and convenient interface, i.e. Pass, to play with. It offers a + * means to ease the development and testing of Relay passes. For example, with + * the pass manager, external users will be able to have custom passes correctly + * scheduled without having to modify a single handcrafted pass order. + * + * In the future we need to describe constraints between passes. For example, + * we may want to preserve dependencies between different passes and validate + * them on the completion of a certain pass. + * + * We also need to store side information and import the error reporting system. + */ +#ifndef TVM_RELAY_TRANSFORM_H_ +#define TVM_RELAY_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +/* + * \brief The context of pass. + */ +class PassContext; + +/*! + * \brief PassContextNode contains the information that a pass can rely on, such as + * analysis results. + */ +class PassContextNode : public RelayNode { + public: + /*! + * \brief The error reporter used to notify users why an optimization fails. + */ + ErrorReporter err_reporter; + + PassContextNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + } + + TVM_DLL static PassContext make(); + + static constexpr const char* _type_key = "relay.PassContext"; + TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); +}; + +TVM_DEFINE_NODE_REF(PassContext, PassContextNode) + +/* + * \brief The meta data of a pass. + * + * PassInfo can be extended conveniently in the future if more meta information + * is needed. + */ +class PassInfo; + +/*! + * \brief PassInfoNode contains meta data that will be used to help optimization + * and analysis. + */ +class PassInfoNode : public RelayNode { + public: + /*! \brief The minimal optimization level that this pass will be enabled. */ + int opt_level; + + /*! \brief The name of an optimization/analysis pass. */ + std::string name; + + /*! \brief The passes that are required to perform the current pass. */ + tvm::Array required; + + PassInfoNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("name", &name); + v->Visit("required", &required); + } + + TVM_DLL static PassInfo make(int opt_level, std::string name, + tvm::Array required); + + static constexpr const char* _type_key = "relay.PassInfo"; + TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); +}; + +TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) + +class Pass; + +/*! + * \brief PassNode is the base type of differnt types of optimization passes. + * It is designed as a pure class and implemented by different pass subclasses + * at different granularity of Relay nodes. + */ +class PassNode : public RelayNode { + public: + /* + * \brief Get the pass information/meta data. */ + virtual PassInfo Info() const = 0; + + /*! + * \brief Set the context information for a pass. + * + * \param pass_ctx The context information for a certain pass. + */ + virtual void SetContext(const PassContext& pass_ctx) = 0; + + /*! + * \brief Execute the optimization pass using a functor. + * + * \param mod The module that an optimization pass runs on. + * + * \return The updated module. + */ + virtual Module operator()(const Module& mod) const = 0; + + void VisitAttrs(tvm::AttrVisitor* v) override {} + + static constexpr const char* _type_key = "relay.Pass"; + TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); +}; + +class Pass : public NodeRef { + public: + Pass() = default; + explicit Pass(NodePtr p) : NodeRef(p) {} + + PassNode* operator->() const { + return static_cast(this->node_.get()); + } + + using ContainerType = PassNode; +}; + +class SequentialNode; + +class Sequential : public Pass { + public: + /*! + * \brief The constructor of `Sequential`. + * \param passes The passes to apply. + * \param pass_info The pass metadata. + * \param disabled The passes that will not be applied. + */ + TVM_DLL Sequential(tvm::Array passes, + PassInfo pass_info, + tvm::Array disabled); + Sequential() = default; + explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} + + const SequentialNode* operator->() const; + using ContainerType = Sequential; +}; + + +/* + * \brief Create a module pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the module pass. + * \param name The name of the module pass. + * \param required The list of the passes that the module pass is dependent on. + * + * \return The created module pass. + */ +Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/* + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORM_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 1f1e4a683ead..d832c8988795 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -25,6 +25,7 @@ from . import module from . import adt from . import ir_pass +from . import transform from .build_module import build, build_config, create_executor from . import prelude from . import parser @@ -97,9 +98,8 @@ var = expr.var const = expr.const bind = expr.bind -module_pass = ir_pass.module_pass -function_pass = ir_pass.function_pass -sequential_pass = ir_pass.sequential_pass +module_pass = transform.module_pass +function_pass = transform.function_pass # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -114,9 +114,9 @@ load_param_dict = param_dict.load_param_dict # Pass manager -PassInfo = ir_pass.PassInfo -PassContext = ir_pass.PassContext -Pass = ir_pass.Pass -ModulePass = ir_pass.ModulePass -FunctionPass = ir_pass.FunctionPass -SequentialPass = ir_pass.SequentialPass +PassInfo = transform.PassInfo +PassContext = transform.PassContext +Pass = transform.Pass +ModulePass = transform.ModulePass +FunctionPass = transform.FunctionPass +Sequential = transform.Sequential diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index 6aedb5248657..13035bb36f71 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -17,62 +17,8 @@ import tvm from . import ir -from .base import NodeBase from .env import Module - -class PassContext(NodeBase): - def __init__(self): - ... - -class PassInfo(NodeBase): - name = ... # type: str - opt_level = ... # type: int - required = ... # type: list - - def __init__(self, name, opt_level, required) - # type: (str, int, list) -> None - - -class Pass(NodeBase): - def __init__(self): - ... - - -class ModulePass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class FunctionPass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class SequentialPass(Pass): - name = ... # type: str - opt_level = ... # type: int - passes = ... # type: list - required = ... # type: list - disabled = ... # type: list - - def __init__(self, name, opt_level, passes, required, disabled): - # type: (str, int, list, list, list) -> None - ... - - def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/_transform.py new file mode 100644 index 000000000000..273d97e0962a --- /dev/null +++ b/python/tvm/relay/_transform.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._transform", __name__) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 5f23e14d5559..ea34c6b1958b 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -17,324 +17,16 @@ # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck """ -This file contains: -1. The set of passes for Relay, which exposes an interface for configuring the - passes and scripting them in Python. - -2. The pass manager for Relay which exposes different granularity of interfaces - for users to implement and use passes more conveniently. +This file contains the set of passes for Relay, which exposes an interface for +configuring the passes and scripting them in Python. """ -import types - from . import _ir_pass from . import _make from .expr import Expr from .ty import Type -from .base import RelayNode, register_relay_node from .module import Module -@register_relay_node -class PassInfo(RelayNode): - """The class that contains the meta data required by a pass. It is the - container of information needed by running an optimization or analysis. - This class can be extended by adding new members when more meta data is - needed. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - required : List[str] - The list of passes that are required by a certain pass. - """ - - def __init__(self, name, opt_level, required=None): - self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level, - required) - - -@register_relay_node -class PassContext(RelayNode): - """The basis where a Relay optimization/analysis runs on. - Each pass context contains a number of auxiliary information that is used - to help an optimization pass. Such information includes the error reporter - to record the errors of during the optimization, etc. - """ - - def __init__(self): - self.__init_handle_by_constructor__(_ir_pass.PassContext) - - -@register_relay_node -class Pass(RelayNode): - """The base class of all passes. All methods here are just simple wrappers - that are implemented in the backend. They are defined for users to - conveniently interact with the base class. - """ - - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - @property - def info(self): - """Get the pass meta.""" - return _ir_pass.Info(self) - - def __call__(self, mod): - """Execute the pass. Note that for sequential pass, the dependency among - different passes will be resolved in the backend. - - Parameters - ---------- - mod : tvm.relay.Module - The module that a certain optimization is performed on. - - Returns - ------- - mod : tvm.relay.Module - The updated module after applying this pass. - """ - return _ir_pass.RunPass(self, mod) - - -@register_relay_node -class ModulePass(Pass): - """A pass that works on tvm.relay.Module. Users don't need to interact with - this class directly. Instead, a module pass should be created through - `module_pass`, because the design of the `module_pass` API is flexible - enough to handle the creation of a module pass in different manners. In - addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass and SequentialPass as well. - """ - - -@register_relay_node -class FunctionPass(Pass): - """A pass that works on each tvm.relay.Function in a module. A function - pass class should be created through `function_pass`. - """ - - -@register_relay_node -class SequentialPass(Pass): - """A pass that works on a sequence of pass objects. A sequential pass class - should be created through `sequential_pass`. - """ - - -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. - - Examples - -------- - The following code creates a module level pass and adds an abs function to - the module. - - .. code-block:: python - - @relay.ir_pass.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, ir_pass.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_func): - """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateModulePass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - -def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created function pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the function pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. - - Examples - -------- - The following code creates a function level pass that performs constant - folding. - - .. code-block:: python - - @relay.ir_pass.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) - - function_pass = transform - assert isinstance(function_pass, ir_pass.FunctionPass) - assert function_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = function_pass(m) - # Now constant folding should have been applied to every function in - # the provided module m. And the updated module will be returned. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the funtion pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_function_pass(pass_func): - """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateFunctionPass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_function_pass(pass_func) - return create_function_pass - - -def sequential_pass(passes=None, opt_level=2, name="sequential_pass", - required=None, disabled=None): - """Create a sequential pass using a defined optimization function from - Python. Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to - apply when running a sequential pass. Pass dependency will be resolved in - the backend as well. - - Parameters - ---------- - passes : Optional[List[Pass]] - A sequence of passes candidate for optimization. - - opt_level : Optional[int] - The optimization level of this sequential pass. - - name : Optional[str] - The name of the sequential pass. - - required : Optional[List[str]] - The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. - - Returns - ------- - ret : Pass - A sequential pass built through pass_func. - """ - - passes = passes if passes else [] - if not isinstance(passes, (list, tuple)): - raise TypeError("passes must be a list of Pass objects.") - - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of list/tuple.") - - return _ir_pass.CreateSequentialPass(passes, opt_level, name, required, - disabled) - - def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py new file mode 100644 index 000000000000..877538afea34 --- /dev/null +++ b/python/tvm/relay/transform.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the pass manager for Relay which exposes different +granularity of interfaces for users to implement and use passes more +conveniently. +""" +import types + +from . import _transform +from .base import RelayNode, register_relay_node + + +@register_relay_node +class PassInfo(RelayNode): + """The class that contains the meta data required by a pass. It is the + container of information needed by running an optimization or analysis. + This class can be extended by adding new members when more meta data is + needed. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + + required : List[str] + The list of passes that are required by a certain pass. + """ + + def __init__(self, name, opt_level, required=None): + self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level, + required) + + +@register_relay_node +class PassContext(RelayNode): + """The basis where a Relay optimization/analysis runs on. + Each pass context contains a number of auxiliary information that is used + to help an optimization pass. Such information includes the error reporter + to record the errors of during the optimization, etc. + """ + + def __init__(self): + self.__init_handle_by_constructor__(_transform.PassContext) + + +@register_relay_node +class Pass(RelayNode): + """The base class of all passes. All methods here are just simple wrappers + that are implemented in the backend. They are defined for users to + conveniently interact with the base class. + """ + + def set_pass_context(self, pass_ctx): + """Setup the pass context for analysis and optimizations. This context + could be shared by different passes for sequential passes. + + Parameters + ---------- + pass_ctx : PassContext + The context that is used to help perform a certain pass or a series + of passes. + """ + if not isinstance(pass_ctx, PassContext): + raise TypeError("pass_ctx is expected to be the PassContext type") + _transform.SetContext(self, pass_ctx) + + @property + def info(self): + """Get the pass meta.""" + return _transform.Info(self) + + def __call__(self, mod): + """Execute the pass. Note that for sequential pass, the dependency among + different passes will be resolved in the backend. + + Parameters + ---------- + mod : tvm.relay.Module + The module that a certain optimization is performed on. + + Returns + ------- + mod : tvm.relay.Module + The updated module after applying this pass. + """ + return _transform.RunPass(self, mod) + + +@register_relay_node +class ModulePass(Pass): + """A pass that works on tvm.relay.Module. Users don't need to interact with + this class directly. Instead, a module pass should be created through + `module_pass`, because the design of the `module_pass` API is flexible + enough to handle the creation of a module pass in different manners. In + addition, all members of a module pass can be accessed from the base class. + The same rule applies to FunctionPass and Sequential as well. + """ + + +@register_relay_node +class FunctionPass(Pass): + """A pass that works on each tvm.relay.Function in a module. A function + pass class should be created through `function_pass`. + """ + + +@register_relay_node +class Sequential(Pass): + """A pass that works on a sequence of pass objects. Multiple passes can be + executed sequentially using this class. + + Some typical usage of the sequential pass are: + 1. Users provide a list of passes for optimization. + 2. Only an optimization level is provided so that the backend system has + to glob all passes at this level and below to perform the optimizations. + Note that users can also provide a series of passes that they don't want to + apply when running a sequential pass. Pass dependency will be resolved in + the backend as well. + + Parameters + ---------- + passes : Optional[List[Pass]] + A sequence of passes candidate for optimization. + + opt_level : Optional[int] + The optimization level of this sequential pass. + + name : Optional[str] + The name of the sequential pass. + + required : Optional[List[str]] + The list of passes that the sequential pass is dependent on. + + disabled : Optional[List[str]] + A list of disabled passes. + """ + + def __init__(self, + passes=None, + opt_level=2, + name="sequential", + required=None, + disabled=None): + passes = passes if passes else [] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a list of Pass objects.") + + disabled = disabled if disabled else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled must be a list or tuple of pass names") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of list/tuple.") + + self.__init_handle_by_constructor__(_transform.Sequential, + passes, opt_level, name, required, + disabled) + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a module pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created module level pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + The callable that will create a module pass is returned when + pass_func is not passed in. Otherwise, a ModulePass object will be + directly created. + + Examples + -------- + The following code creates a module level pass and adds an abs function to + the module. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = relay.Module({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_func): + """Internal function that creates a module pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _transform.CreateModulePass( + pass_func, opt_level, name if name else pass_func.__name__, + required) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass + + +def function_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a function pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + The callable that will create a function pass is returned when + pass_func is not passed in. Otherwise, a FunctionPass object will be + created. + + Examples + -------- + The following code creates a function level pass that performs constant + folding. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=2) + def transform(func, ctx): + return ir_pass.fold_constant(func) + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the funtion pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_func): + """Internal function that creates a function pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _transform.CreateFunctionPass( + pass_func, opt_level, name if name else pass_func.__name__, + required) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi new file mode 100644 index 000000000000..343e89976b09 --- /dev/null +++ b/python/tvm/relay/transform.pyi @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from .base import NodeBase + + +class PassContext(NodeBase): + def __init__(self): + ... + +class PassInfo(NodeBase): + name = ... # type: str + opt_level = ... # type: int + required = ... # type: list + + def __init__(self, name, opt_level, required) + # type: (str, int, list) -> None + + +class Pass(NodeBase): + def __init__(self): + ... + + +class ModulePass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class FunctionPass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class Sequential(Pass): + name = ... # type: str + opt_level = ... # type: int + passes = ... # type: list + required = ... # type: list + disabled = ... # type: list + + def __init__(self, name, opt_level, passes, required, disabled): + # type: (str, int, list, list, list) -> None + ... diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index d607247b3bc8..a105b692aa9d 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -23,11 +23,11 @@ * \brief Relay pass manager implementation. */ #include -#include +#include namespace tvm { namespace relay { -namespace pass { +namespace transform { using tvm::IRPrinter; @@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode { RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); -class SequentialPass; - /*! - * \brief The SequentialPassNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly * perform a host of optimizations with a given optimization level and disabled * passes. */ -class SequentialPassNode : public PassNode { +class SequentialNode : public PassNode { public: /* \brief The pass meta data.*/ PassInfo pass_info; @@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode { passes.push_back(pass); } - TVM_DLL static SequentialPass make(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled); - /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode { */ void SetContext(const PassContext& pass_ctx) final; - static constexpr const char* _type_key = "relay.SequentialPass"; - TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode); + static constexpr const char* _type_key = "relay.Sequential"; + TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); private: /*! @@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode { PassContext pass_ctx_; }; -RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass); - PassInfo PassInfoNode::make(int opt_level, std::string name, tvm::Array required) { auto pass_info = make_node(); @@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { return pval && pval->value != 0; } -SequentialPass SequentialPassNode::make(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled) { - auto n = make_node(); +Sequential::Sequential(tvm::Array passes, + PassInfo pass_info, + tvm::Array disabled) { + auto n = make_node(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); n->disabled = std::move(disabled); - return SequentialPass(n); + node_ = std::move(n); +} + +const SequentialNode* Sequential::operator->() const { + return static_cast(this->node_.get()); } // TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in -// a SequentialPass without the consideration of their orders. The phase +// a Sequential without the consideration of their orders. The phase // ordering problem needed to be handled in the future. -Module SequentialPassNode::operator()(const Module& module) const { +Module SequentialNode::operator()(const Module& module) const { Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; @@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const { return mod; } -void SequentialPassNode::ResolveDependency(const Module& mod) { +void SequentialNode::ResolveDependency(const Module& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. @@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) { << "\n"; } -std::vector SequentialPassNode::DisabledPasses() const { +std::vector SequentialNode::DisabledPasses() const { std::vector ret; for (const auto& it : disabled) { const auto* str = it.as(); @@ -392,7 +388,7 @@ std::vector SequentialPassNode::DisabledPasses() const { return ret; } -void SequentialPassNode::SetContext(const PassContext& pass_ctx) { +void SequentialNode::SetContext(const PassContext& pass_ctx) { pass_ctx_ = pass_ctx; } @@ -414,21 +410,12 @@ Pass CreateFunctionPass( return FunctionPassNode::make(pass_func, pass_info); } -Pass CreateSequentialPass(const tvm::Array& passes, - int opt_level, - const std::string& name, - const tvm::Array& required, - const tvm::Array& disabled) { - PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - return SequentialPassNode::make(passes, pass_info, disabled); -} - TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_API("relay._ir_pass.PassInfo") +TVM_REGISTER_API("relay._transform.PassInfo") .set_body_typed(PassInfoNode::make); -TVM_REGISTER_API("relay._ir_pass.Info") +TVM_REGISTER_API("relay._transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); @@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_API("relay._ir_pass.CreateModulePass") +TVM_REGISTER_API("relay._transform.CreateModulePass") .set_body_typed(CreateModulePass); -TVM_REGISTER_API("relay._ir_pass.RunPass") +TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; Module mod = args[1]; @@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") +TVM_REGISTER_API("relay._transform.CreateFunctionPass") .set_body_typed(CreateFunctionPass); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << " at the optimization level " << pn->opt_level; }); -TVM_REGISTER_NODE_TYPE(SequentialPassNode); +TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") +TVM_REGISTER_API("relay._transform.Sequential") .set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; @@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") tvm::Array required = args[3]; tvm::Array disabled = args[4]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - *ret = SequentialPassNode::make(passes, pass_info, disabled); + *ret = Sequential(passes, pass_info, disabled); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const SequentialPassNode* node, - tvm::IRPrinter* p) { +.set_dispatch([](const SequentialNode* node, + tvm::IRPrinter* p) { const PassInfoNode* seq_pn = node->Info().operator->(); - p->stream << "Run SequentialPass pass: " << seq_pn->name + p->stream << "Run Sequential pass: " << seq_pn->name << " at the optimization level. " << seq_pn->opt_level; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { @@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -TVM_REGISTER_API("relay._ir_pass.SetContext") +TVM_REGISTER_API("relay._transform.SetContext") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; PassContext pass_ctx = args[1]; @@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_API("relay._ir_pass.PassContext") +TVM_REGISTER_API("relay._transform.PassContext") .set_body_typed(PassContextNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << "\n"; }); -} // namespace pass +} // namespace transform } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index b8216775ee1c..db346e7f712f 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -22,6 +22,7 @@ from tvm.relay import ExprFunctor from tvm.relay import Function, Call from tvm.relay import ir_pass +from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list @@ -126,13 +127,13 @@ def test_module_pass(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.module_pass(opt_level=opt_level, name=pass_name) + @_transform.module_pass(opt_level=opt_level, name=pass_name) def transform(expr, ctx): return opt_tester.transform(expr, ctx) def test_pass_registration(): mod_pass = transform - assert isinstance(mod_pass, ir_pass.ModulePass) + assert isinstance(mod_pass, _transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level @@ -140,8 +141,8 @@ def test_pass_registration(): def test_pass_registration_no_decorator(): def direct_transform(expr, ctx): return opt_tester.transform(expr, ctx) - mod_pass = ir_pass.module_pass(direct_transform, opt_level=3) - assert isinstance(mod_pass, ir_pass.ModulePass) + mod_pass = _transform.module_pass(direct_transform, opt_level=3) + assert isinstance(mod_pass, _transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == "direct_transform" assert pass_info.opt_level == 3 @@ -202,7 +203,7 @@ def test_function_pass(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.function_pass(opt_level=opt_level, name=pass_name) + @_transform.function_pass(opt_level=opt_level, name=pass_name) def transform(expr, ctx): return opt_tester.transform(expr, ctx) @@ -212,7 +213,7 @@ def get_ref_log(): def test_pass_registration(): function_pass = transform - assert isinstance(function_pass, ir_pass.FunctionPass) + assert isinstance(function_pass, _transform.FunctionPass) pass_info = function_pass.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level @@ -220,8 +221,8 @@ def test_pass_registration(): def test_pass_registration_no_decorator(): def direct_transform(expr, ctx): return opt_tester.transform(expr, ctx) - mod_pass = ir_pass.function_pass(direct_transform, opt_level=0) - assert isinstance(mod_pass, ir_pass.FunctionPass) + mod_pass = _transform.function_pass(direct_transform, opt_level=0) + assert isinstance(mod_pass, _transform.FunctionPass) pass_info = mod_pass.info assert pass_info.name == "direct_transform" assert pass_info.opt_level == 0 @@ -294,14 +295,14 @@ def get_ref_abs(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.module_pass(opt_level=1) + @_transform.module_pass(opt_level=1) def mod_transform(expr, ctx): return opt_tester.transform(expr, ctx) module_pass = mod_transform # Register a function pass. - @ir_pass.function_pass(opt_level=1) + @_transform.function_pass(opt_level=1) def func_transform(expr, ctx): return opt_tester.transform(expr, ctx) @@ -310,25 +311,23 @@ def func_transform(expr, ctx): def test_pass_registration(): passes = [module_pass, function_pass] opt_level = 2 - pass_name = "sequential_pass" - sequential_pass = ir_pass.sequential_pass(passes=passes, - opt_level=opt_level) - assert isinstance(sequential_pass, ir_pass.SequentialPass) - pass_info = sequential_pass.info + pass_name = "sequential" + sequential = _transform.Sequential(passes=passes, opt_level=opt_level) + pass_info = sequential.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level def test_no_pass(): passes = [] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) mod_func = ret_mod[v_sub] check_func(sub, mod_func) def test_only_module_pass(): passes = [module_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, sub) @@ -341,8 +340,8 @@ def test_only_module_pass(): def test_only_function_pass(): # Check the subtract function. passes = [function_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -355,8 +354,8 @@ def test_multiple_passes(): # function pass. mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) # Check the abs function is added. abs_var, abs_func = get_var_func() From 0dbfa2aca4ea2c9f830cf91e9e07b1390005e3e9 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 22 May 2019 13:57:53 -0700 Subject: [PATCH 035/176] [Relay][Prelude] Remove Peano nats from the prelude (#3045) --- python/tvm/relay/prelude.py | 132 ++++++------- python/tvm/relay/testing/__init__.py | 1 + python/tvm/relay/testing/nat.py | 184 ++++++++++++++++++ tests/python/relay/test_adt.py | 121 ++++++------ tests/python/relay/test_ir_well_formed.py | 8 +- tests/python/relay/test_pass_alpha_equal.py | 6 +- tests/python/relay/test_pass_gradient.py | 4 +- .../relay/test_pass_to_a_normal_form.py | 15 +- 8 files changed, 326 insertions(+), 145 deletions(-) create mode 100644 python/tvm/relay/testing/nat.py diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index ff823c3413fa..92647e5b14b4 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -17,7 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """Adds certain standard global functions and ADT definitions to the module.""" from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type -from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem +from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const +from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard @@ -34,6 +35,7 @@ def define_list_adt(self): self.cons = Constructor("cons", [a, self.l(a)], self.l) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + def define_list_hd(self): """Defines a function to get the head of a list. Assume the list has at least one element. @@ -48,6 +50,7 @@ def define_list_hd(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) + def define_list_tl(self): """Defines a function to get the tail of a list. @@ -61,39 +64,44 @@ def define_list_tl(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a]) + def define_list_nth(self): """Defines a function to get the nth element of a list. - nth(l) : list[a] -> a + nth(l) : list[a] -> Tensor[(), int32] -> a """ self.nth = GlobalVar("nth") a = TypeVar("a") x = Var("x", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) + + body = If(equal(n, const(0)), + self.hd(x), + self.nth(self.tl(x), subtract(n, const(1)))) + + self.mod[self.nth] = Function([x, n], body, a, [a]) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.hd(x)) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) - self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) def define_list_update(self): """Defines a function to update the nth element of a list and return the updated list. - update(l, i, v) : list[a] -> nat -> a -> list[a] + update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a] """ self.update = GlobalVar("update") a = TypeVar("a") l = Var("l", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) v = Var("v", a) - y = Var("y") + body = If(equal(n, const(0)), + self.cons(v, self.tl(l)), + self.cons(self.hd(l), + self.update(self.tl(l), + subtract(n, const(1)), + v))) - z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l))) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.cons(self.hd(l), self.update(self.tl(l), y, v))) + self.mod[self.update] = Function([l, n, v], body, self.l(a), [a]) - self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a]) def define_list_map(self): """Defines a function for mapping a function over a list's @@ -114,6 +122,7 @@ def define_list_map(self): self.cons(f(y), self.map(f, z))) self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) + def define_list_foldl(self): """Defines a left-way fold over a list. @@ -136,6 +145,7 @@ def define_list_foldl(self): self.mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), a, [a, b]) + def define_list_foldr(self): """Defines a right-way fold over a list. @@ -158,6 +168,7 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), b, [a, b]) + def define_list_foldr1(self): """Defines a right-way fold over a nonempty list. @@ -196,6 +207,7 @@ def define_list_concat(self): self.foldr(updater, l2, l1), self.l(a), [a]) + def define_list_filter(self): """Defines a function that filters a list. @@ -214,6 +226,7 @@ def define_list_filter(self): If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) + def define_list_zip(self): """Defines a function that combines two lists into a list of tuples of their elements. @@ -238,6 +251,7 @@ def define_list_zip(self): self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), self.l(TupleType([a, b])), [a, b]) + def define_list_rev(self): """Defines a function that reverses a list. @@ -253,6 +267,7 @@ def define_list_rev(self): self.foldl(updater, self.nil(), l), self.l(a), [a]) + def define_list_map_accumr(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -282,6 +297,7 @@ def define_list_map_accumr(self): TupleType([a, self.l(c)]), [a, b, c]) + def define_list_map_accuml(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -321,6 +337,7 @@ def define_optional_adt(self): self.none = Constructor("none", [], self.optional) self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) + def define_list_unfoldr(self): """Defines a function that builds up a list starting from a seed value. @@ -343,6 +360,7 @@ def define_list_unfoldr(self): self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), self.l(b), [a, b]) + def define_list_unfoldl(self): """Defines a function that builds up a list starting from a seed value. @@ -362,52 +380,29 @@ def define_list_unfoldl(self): self.rev(self.unfoldr(f, s)), self.l(b), [a, b]) - def define_nat_adt(self): - """Defines a Peano (unary) natural number ADT. - Zero is represented by z(). s(n) adds 1 to a nat n.""" - self.nat = GlobalTypeVar("nat") - self.z = Constructor("z", [], self.nat) - self.s = Constructor("s", [self.nat()], self.nat) - self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) - - def define_nat_double(self): - """Defines a function that doubles a nat.""" - self.double = GlobalVar("double") - x = Var("x", self.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.z()) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.s(self.s(self.double(y)))) - self.mod[self.double] = Function([x], Match(x, [z_case, s_case])) - - def define_nat_add(self): - """Defines a function that adds two nats.""" - self.add = GlobalVar("add") - x = Var("x", self.nat()) - y = Var("y", self.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(self.z), y) - s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), - self.s(self.add(a, y))) - self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) def define_list_sum(self): - """Defines a function that computes the sum of a list of nats.""" + """Defines a function that computes the sum of a list of integer scalars.""" self.sum = GlobalVar("sum") - a = Var("a", self.l(self.nat())) - self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + a = Var("a", self.l(scalar_type('int32'))) + x = Var('x') + y = Var('y') + addf = Function([x, y], add(x, y)) + self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a)) + def define_list_length(self): - """Defines a function that returns the length of a list as a nat""" + """Defines a function that returns the length of a list""" self.length = GlobalVar("length") a = TypeVar("a") x = Var("x", self.l(a)) y = Var("y") - nil_case = Clause(PatternConstructor(self.nil), self.z()) + nil_case = Clause(PatternConstructor(self.nil), const(0)) cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), - self.s(self.length(y))) + add(const(1), self.length(y))) self.mod[self.length] = Function([x], - Match(x, [nil_case, cons_case]), None, [a]) + Match(x, [nil_case, cons_case]), scalar_type('int32'), [a]) + def define_tree_adt(self): """Defines a tree ADT. A tree can contain any type. @@ -420,6 +415,7 @@ def define_tree_adt(self): self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + def define_tree_map(self): """Defines a function that maps over a tree. The function is applied to each subtree's contents. @@ -439,23 +435,24 @@ def define_tree_map(self): self.mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) + def define_tree_size(self): - """Defines a function that computes the size of a tree as a nat. + """Defines a function that computes the size of a tree. - Signature: fn(t : tree[a]) -> nat + Signature: fn(t : tree[a]) -> Tensor[(), int32] """ self.size = GlobalVar("size") a = TypeVar("a") t = Var("t", self.tree(a)) - x = Var("x", self.tree(a)) z = Var("z") rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), - self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + add(const(1), self.sum(self.map(self.size, z)))) self.mod[self.size] = Function([t], - Match(t, [rose_case]), self.nat(), [a]) + Match(t, [rose_case]), scalar_type('int32'), [a]) + def define_id(self): - """Defines a function that return it's argument. + """Defines a function that return its argument. Signature: fn(x : a) -> a """ @@ -466,7 +463,7 @@ def define_id(self): def define_compose(self): - """Defines a function that compose two function. + """Defines a function that composes two function. Signature: fn(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c """ @@ -484,24 +481,26 @@ def define_compose(self): def define_iterate(self): - """Define a function that take a number n, a function f, - and return a closure that apply f n time on it's argument. + """Defines a function that take a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. - Signature: fn(n : nat, f : fn(a) -> a) -> fn(a) -> a + Signature: fn(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a """ self.iterate = GlobalVar("iterate") a = TypeVar("a") f = Var("f", FuncType([a], a)) - x = Var("x", self.nat()) - y = Var("y", self.nat()) - z_case = Clause(PatternConstructor(self.z), self.id) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.compose(f, self.iterate(f, y))) + x = Var("x", scalar_type('int32')) + body = If(equal(x, const(0)), + self.id, + self.compose(f, + self.iterate(f, subtract(x, const(1))))) self.mod[self.iterate] = Function([f, x], - Match(x, [z_case, s_case]), + body, FuncType([a], a), [a]) + def __init__(self, mod): self.mod = mod self.define_list_adt() @@ -522,9 +521,6 @@ def __init__(self, mod): self.define_list_unfoldr() self.define_list_unfoldl() - self.define_nat_adt() - self.define_nat_double() - self.define_nat_add() self.define_list_length() self.define_list_nth() self.define_list_update() diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index b4a8394e2659..192afe1ef914 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -30,3 +30,4 @@ from .config import ctx_list from .init import create_workload +from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py new file mode 100644 index 000000000000..4c0c87ce8a9e --- /dev/null +++ b/python/tvm/relay/testing/nat.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Defines a unary natural number (Peano natural number) abstract +data type for Relay and provides some utility functions for it. +Nats are useful for testing purposes, as they make it easy to write +test cases for recursion and pattern matching.""" + +from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType + +def define_nat_adt(prelude): + """Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n. + Adds the fields nat, z, and s to the preluide, representing + (respectively) the nat ADT and the z and s constructors. + """ + prelude.nat = GlobalTypeVar("nat") + prelude.z = Constructor("z", [], prelude.nat) + prelude.s = Constructor("s", [prelude.nat()], prelude.nat) + prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s]) + + +def define_nat_double(prelude): + """Defines a function that doubles a nat. Adds a field called + 'double' to the prelude, giving the GlobalVar pointing to + the function. + """ + prelude.double = GlobalVar("double") + x = Var("x", prelude.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(prelude.z), prelude.z()) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.s(prelude.s(prelude.double(y)))) + prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case])) + + +def define_nat_add(prelude): + """Defines a function that adds two nats and adds a field to the + prelude 'add' giving the GlobalVar pointing to that function. + """ + prelude.add = GlobalVar("add") + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(prelude.z), y) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]), + prelude.s(prelude.add(a, y))) + prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case])) + + +# versions of prelude functions that use nats instead of scalars + +def define_nat_nth(prelude): + """Defines a function to get the nth eleemnt of a list using + a nat to index into the list. + + nat_nth(l, n): fun(list[a], nat) -> a + """ + prelude.nat_nth = GlobalVar("nat_nth") + a = TypeVar("a") + x = Var("x", prelude.l(a)) + n = Var("n", prelude.nat()) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x)) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.nat_nth(prelude.tl(x), y)) + + prelude.mod[prelude.nat_nth] = Function([x, n], + Match(n, [z_case, s_case]), + a, [a]) + + +def define_nat_update(prelude): + """Defines a function to update the nth element of a list and return the updated list. + + nat_update(l, i, v) : fun(list[a], nat, a) -> list[a] + """ + prelude.nat_update = GlobalVar("nat_update") + a = TypeVar("a") + # pylint: disable=invalid-name + l = Var("l", prelude.l(a)) + n = Var("n", prelude.nat()) + v = Var("v", a) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), + prelude.cons(v, prelude.tl(l))) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.cons( + prelude.hd(l), + prelude.nat_update(prelude.tl(l), y, v))) + + prelude.mod[prelude.nat_update] = Function([l, n, v], + Match(n, [z_case, s_case]), + prelude.l(a), [a]) + + +def define_nat_iterate(prelude): + """Defines a function that takes a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. + + Signature: fn(fn(a) -> a, nat) -> fn(a) -> a + """ + prelude.nat_iterate = GlobalVar("nat_iterate") + a = TypeVar("a") + f = Var("f", FuncType([a], a)) + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + + z_case = Clause(PatternConstructor(prelude.z), prelude.id) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.compose(f, prelude.nat_iterate(f, y))) + + prelude.mod[prelude.nat_iterate] = Function([f, x], + Match(x, [z_case, s_case]), + FuncType([a], a), + [a]) + + +def add_nat_definitions(prelude): + """Given a Relay prelude, adds a Peano nat ADT, as well as functions + for adding nats and doubling nats. It also adds versions of + update, nth, and iterate that take nats instead of scalars (the + names are prefixed with 'nat_').""" + define_nat_adt(prelude) + define_nat_double(prelude) + define_nat_add(prelude) + define_nat_nth(prelude) + define_nat_update(prelude) + define_nat_iterate(prelude) + + +# helper functions for working with nats + + +def count(n): + """Takes a ConstructorValue corresponding to a nat ADT + and converts it into a Python integer. This is an example of + using an ADT value in Python. + """ + assert isinstance(n, ConstructorValue) + if n.constructor.name_hint == 'z': + return 0 + assert n.constructor.name_hint == 's' + return 1 + count(n.fields[0]) + + +def make_nat_value(prelude, n): + """The inverse of count(): Given a non-negative Python integer, + constructs a ConstructorValue representing that value as a nat. + """ + if n == 0: + return ConstructorValue(prelude.z, [], []) + return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], []) + + +def make_nat_expr(prelude, n): + """Given a non-negative Python integer, constructs a Python + expression representing that integer's value as a nat. + """ + assert n >= 0 + ret = prelude.z() + while n > 0: + ret = prelude.s(ret) + n = n - 1 + return ret diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 58ab0c481f9c..77f4ab1f16a0 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -14,15 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import tvm from tvm import relay from tvm.relay.ir_pass import infer_type from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay import testing, create_executor from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr mod = relay.Module() p = Prelude(mod) +add_nat_definitions(p) + ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") @@ -67,15 +71,6 @@ compose = p.compose iterate = p.iterate -# this is an example of using the adt value in python side -def count(n): - assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 's': - return 1 + count(n.fields[0]) - else: - assert n.constructor.name_hint == 'z' - return 0 - # this is an example of creating the adt value in python side def make_nat(n): if n != 0: @@ -83,7 +78,7 @@ def make_nat(n): else: return ConstructorValue(z, [], []) -def build_nat(n): +def make_nat_expr(n): assert n >= 0 ret = z() while n > 0: @@ -115,8 +110,14 @@ def tree_to_dict(t): ret['children'].append(l) return ret + +# turns a scalar-valued relay tensor value into a python number +def get_scalar(tv): + return tv.asnumpy().item() + + def test_nat_value(): - assert count(make_nat(10)) == 10 + assert count(make_nat_value(p, 10)) == 10 assert count(intrp.evaluate(s(s(z())))) == 2 @@ -145,7 +146,7 @@ def test_hd_tl(): expected = list(range(10)) l = nil() for i in reversed(expected): - l = cons(build_nat(i), l) + l = cons(make_nat_expr(i), l) got = [] for i in range(len(expected)): @@ -158,36 +159,35 @@ def test_nth(): expected = list(range(10)) l = nil() for i in reversed(expected): - l = cons(build_nat(i), l) + l = cons(relay.const(i), l) - got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + item = intrp.evaluate(nth(l, relay.const(i))) + assert get_scalar(item) == i - assert got == expected def test_update(): expected = list(range(10)) l = nil() # create zero initialized list for i in range(len(expected)): - l = cons(build_nat(0), l) + l = cons(make_nat_expr(0), l) # set value for i, v in enumerate(expected): - l = update(l, build_nat(i), build_nat(v)) + l = update(l, relay.const(i), make_nat_expr(v)) got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + got.append(count(intrp.evaluate(nth(l, relay.const(i))))) assert got == expected def test_length(): a = relay.TypeVar("a") - assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a]) res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) - assert count(res) == 3 + assert get_scalar(res) == 3 def test_map(): @@ -216,9 +216,9 @@ def test_foldl(): y = relay.Var("y") rev_dup = relay.Function([y, x], cons(x, cons(x, y))) res = intrp.evaluate(foldl(rev_dup, nil(), - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) reversed = to_list(res) assert len(reversed) == 6 assert count(reversed[0]) == 3 and count(reversed[1]) == 3 @@ -237,9 +237,9 @@ def test_foldr(): y = relay.Var("y") identity = relay.Function([x, y], cons(x, y)) res = intrp.evaluate(foldr(identity, nil(), - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) same = to_list(res) assert len(same) == 3 assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3 @@ -255,25 +255,25 @@ def test_foldr1(): y = relay.Var("y") f = relay.Function([x, y], add(x, y)) res = intrp.evaluate(foldr1(f, - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) assert count(res) == 6 def test_sum(): - assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) - res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil())))) - assert count(res) == 3 + assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32')) + res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil())))) + assert get_scalar(res) == 3 def test_concat(): a = relay.TypeVar("a") assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a]) - l1 = cons(build_nat(1), cons(build_nat(2), nil())) - l2 = cons(build_nat(3), cons(build_nat(4), nil())) + l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil())) + l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil())) res = intrp.evaluate(concat(l1, l2)) catted = to_list(res) @@ -305,12 +305,12 @@ def test_filter(): ])) res = intrp.evaluate( filter(greater_than_one, - cons(build_nat(1), - cons(build_nat(1), - cons(build_nat(3), - cons(build_nat(1), - cons(build_nat(5), - cons(build_nat(1), + cons(make_nat_expr(1), + cons(make_nat_expr(1), + cons(make_nat_expr(3), + cons(make_nat_expr(1), + cons(make_nat_expr(5), + cons(make_nat_expr(1), nil())))))))) filtered = to_list(res) assert len(filtered) == 2 @@ -325,7 +325,7 @@ def test_zip(): l(relay.TupleType([a, b])), [a, b]) assert mod[zip].checked_type == expected_type - l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), @@ -342,7 +342,7 @@ def test_zip(): assert len(to_list(zipped[2][1])) == 2 # test truncation - l3 = cons(build_nat(4), cons(build_nat(5), nil())) + l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil())) shorter_res = intrp.evaluate(zip(l3, l2)) truncated = to_list(shorter_res) assert len(truncated) == 2 @@ -363,9 +363,9 @@ def test_rev(): a = relay.TypeVar("a") assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a]) - res = intrp.evaluate(rev(cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + res = intrp.evaluate(rev(cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) reversed = to_list(res) assert len(reversed) == 3 @@ -392,7 +392,7 @@ def test_unfoldr(): relay.Clause(relay.PatternConstructor(z, []), none()) ])) - res = intrp.evaluate(unfoldr(count_down, build_nat(3))) + res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -419,7 +419,7 @@ def test_unfoldl(): relay.Clause(relay.PatternConstructor(z, []), none()) ])) - res = intrp.evaluate(unfoldl(count_down, build_nat(3))) + res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -444,7 +444,7 @@ def test_map_accumr(): relay.Tuple([add(x, acc), add(x, acc)])) - vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) sum = count(res[0]) @@ -472,7 +472,7 @@ def test_map_accuml(): add_to_acc = relay.Function([acc, x], relay.Tuple([add(x, acc), x])) - vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) sum = count(res[0]) @@ -497,8 +497,8 @@ def test_optional_matching(): ])) res = intrp.evaluate(foldr(condense, nil(), cons( - some(build_nat(3)), - cons(none(), cons(some(build_nat(1)), nil()))))) + some(make_nat_expr(3)), + cons(none(), cons(some(make_nat_expr(1)), nil()))))) reduced = to_list(res) assert len(reduced) == 2 @@ -532,7 +532,7 @@ def test_tmap(): def test_size(): a = relay.TypeVar("a") lhs = mod[size].checked_type - rhs = relay.FuncType([tree(a)], nat(), [a]) + rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a]) assert lhs == rhs root = rose(z(), cons(rose(z(), nil()), @@ -540,7 +540,7 @@ def test_size(): nil()))) t = rose(z(), cons(root, cons(root, cons(root, nil())))) res = intrp.evaluate(size(t)) - assert count(res) == 10 + assert get_scalar(res) == 10 def test_wildcard_match_solo(): @@ -601,10 +601,10 @@ def test_nested_matches(): inner_match) ]), l(a), [a]) - first_list = cons(build_nat(1), cons(build_nat(2), - cons(build_nat(3), nil()))) - second_list = cons(build_nat(4), cons(build_nat(5), - cons(build_nat(6), nil()))) + first_list = cons(make_nat_expr(1), cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))) + second_list = cons(make_nat_expr(4), cons(make_nat_expr(5), + cons(make_nat_expr(6), nil()))) final_list = cons(first_list, cons(second_list, nil())) res = intrp.evaluate(flatten(final_list)) @@ -660,6 +660,7 @@ def test_nested_pattern_match(): assert count(res) == 2 + def test_compose(): n = relay.Var('n') inc = relay.Function([n], s(n)) @@ -667,11 +668,13 @@ def test_compose(): res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))])) assert count(res) == 5 + def test_iterate(): - expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)]) + expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)]) res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 + if __name__ == "__main__": test_nat_constructor() test_double() diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index e69f839e3ee6..3cf73ae2cc66 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -53,10 +53,12 @@ def test_adt(): mod = relay.Module() p = Prelude(mod) x = relay.Var("x") - s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x) + some_case = relay.Clause(relay.PatternConstructor(p.some, + [relay.PatternVar(x)]), + x) default_case = relay.Clause(relay.PatternVar(x), x) - m0 = relay.Match(p.z(), [default_case]) - m1 = relay.Match(p.z(), [s_case, default_case]) + m0 = relay.Match(p.none(), [default_case]) + m1 = relay.Match(p.none(), [some_case, default_case]) assert well_formed(m0) assert not well_formed(m1) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index f00dc85eb7f8..478b433180b9 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -521,7 +521,7 @@ def test_match_alpha_equal(): relay.PatternVar(a)]), p.cons(z, a)) - data = p.cons(p.z(), p.cons(p.z(), p.nil())) + data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil())) match = relay.Match(data, [nil_case, cons_case]) equivalent = relay.Match(data, [nil_case, equivalent_cons]) @@ -547,8 +547,8 @@ def test_match_alpha_equal(): relay.Clause(relay.PatternWildcard(), p.nil()) ]) wrong_constructors = relay.Match(data, [ - relay.Clause(relay.PatternConstructor(p.z), p.nil()), - relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), + relay.Clause(relay.PatternConstructor(p.none), p.nil()), + relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]), p.cons(x, p.nil())) ]) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index f5968a41f028..d99bee58b99b 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -19,6 +19,7 @@ from tvm.relay.ir_pass import free_vars, free_type_vars, gradient from tvm.relay import create_executor from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, make_nat_expr import numpy as np @@ -174,13 +175,14 @@ def test_tuple(): def test_pow(): mod = relay.Module() p = Prelude(mod) + add_nat_definitions(p) shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) x = relay.var("x", t) double = relay.Function([x], x + x) i = relay.var("i", t) - func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i])) + func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) i_nd = rand(dtype, *shape) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 2e95dbe55121..f395580a3f84 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -21,6 +21,7 @@ from tvm.relay import op, create_executor from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, count def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -130,19 +131,10 @@ def test_ref(): check_eval(to_a_normal_form(body), 3) -# this is an example of using the adt value in python side -def count(n): - assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 's': - return 1 + count(n.fields[0]) - else: - assert n.constructor.name_hint == 'z' - return 0 - - -def test_add(): +def test_nat_add(): mod = relay.Module() p = Prelude(mod) + add_nat_definitions(p) nat = p.nat add = p.add s = p.s @@ -183,4 +175,5 @@ def test_function(): test_ref() test_add() test_let() + test_nat_add() test_function() From 522105addd98a66536ef4fdd4aa6756fb14dd3b6 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 22 May 2019 15:41:02 -0700 Subject: [PATCH 036/176] Register SkipVectorize (#3228) --- src/api/api_pass.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index d6c92aee94d1..e5b003cafb87 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -130,6 +130,7 @@ REGISTER_PASS(RewriteUnsafeSelect); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); REGISTER_PASS(VectorizeLoop); +REGISTER_PASS(SkipVectorize); REGISTER_PASS(UnrollLoop); REGISTER_PASS(InjectCopyIntrin); REGISTER_PASS(ThreadSync); From 982841e86ea4ffa0386750ca86caf178d897676d Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 22 May 2019 16:22:13 -0700 Subject: [PATCH 037/176] [3rdparty] sync submodules (#3229) --- 3rdparty/HalideIR | 2 +- 3rdparty/dlpack | 2 +- 3rdparty/dmlc-core | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index ec9585a5a5df..32057b53eee8 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit ec9585a5a5df3de91e8916ac2d27a4a509eac5fc +Subproject commit 32057b53eee870d73c6c21dc820d6546b4d9a13f diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 5c792cef3aee..0acb731e0e43 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 5c792cef3aee54ad8b7000111c9dc1797f327b59 +Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 82bf4c2e2af3..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 82bf4c2e2af312b3d52513aa727483803a2f8734 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f From e6b68b45c90ee176a6480dcb1ecbe289bbcb1518 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Thu, 23 May 2019 10:13:11 -0700 Subject: [PATCH 038/176] [GraphRuntime] Debug graph runtime (#3232) --- python/tvm/contrib/debugger/debug_result.py | 13 ++-- python/tvm/contrib/debugger/debug_runtime.py | 42 ++++++++----- .../graph/debug/graph_runtime_debug.cc | 63 +++++++------------ .../unittest/test_runtime_graph_debug.py | 4 -- 4 files changed, 55 insertions(+), 67 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index c53a2c287339..882364dd3971 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -207,10 +207,8 @@ def dump_graph_json(self, graph): def display_debug_result(self): """Displays the debugger result" """ - header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Start Time", \ - "End Time", "Shape", "Inputs", "Outputs"] - lines = ["---------", "---", "--------", "-------", "----------", \ - "--------", "-----", "------", "-------"] + header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Shape", "Inputs", "Outputs"] + lines = ["---------", "---", "--------", "-------", "-----", "------", "-------"] eid = 0 data = [] total_time = sum(time[0] for time in self._time_list) @@ -223,12 +221,11 @@ def display_debug_result(self): continue name = node['name'] shape = str(self._output_tensor_list[eid].shape) - time_us = round(time[0] * 1000000, 2) - time_percent = round(((time[0] / total_time) * 100), 2) + time_us = round(time[0] * 1000000, 3) + time_percent = round(((time[0] / total_time) * 100), 3) inputs = str(node['attrs']['num_inputs']) outputs = str(node['attrs']['num_outputs']) - node_data = [name, op, time_us, time_percent, str(time[1]), str(time[2]), \ - shape, inputs, outputs] + node_data = [name, op, time_us, time_percent, shape, inputs, outputs] data.append(node_data) eid += 1 fmt = "" diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 01cda35769a5..f77a927eeabf 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -19,7 +19,6 @@ import os import tempfile import shutil -from datetime import datetime from tvm._ffi.base import string_types from tvm._ffi.function import get_global_func from tvm.contrib import graph_runtime @@ -30,6 +29,7 @@ _DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_PATH_PREFIX = "_tvmdbg_" + def create(graph_json_str, libmod, ctx, dump_root=None): """Create a runtime executor module given a graph and module. @@ -62,17 +62,23 @@ def create(graph_json_str, libmod, ctx, dump_root=None): try: fcreate = get_global_func("tvm.graph_runtime_debug.create") except ValueError: - raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \ - "config.cmake and rebuild TVM to enable debug mode") + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " + "config.cmake and rebuild TVM to enable debug mode" + ) ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): libmod = rpc_base._ModuleHandle(libmod) try: - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create") + fcreate = ctx[0]._rpc_sess.get_function( + "tvm.graph_runtime_debug.remote_create" + ) except ValueError: - raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \ - "config.cmake and rebuild TVM to enable debug mode") + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " + "config.cmake and rebuild TVM to enable debug mode" + ) func_obj = fcreate(graph_json_str, libmod, *device_type_id) return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root) @@ -100,10 +106,10 @@ class GraphModuleDebug(graph_runtime.GraphModule): To select which folder the outputs should be kept. None will make a temp folder in /tmp/tvmdbg and does the dumping """ + def __init__(self, module, ctx, graph_json_str, dump_root): self._dump_root = dump_root self._dump_path = None - self._debug_run = module["debug_run"] self._get_output_by_layer = module["get_output_by_layer"] self._run_individual = module["run_individual"] graph_runtime.GraphModule.__init__(self, module) @@ -181,13 +187,10 @@ def _run_debug(self): Time consumed for each execution will be set as debug output. """ - self.debug_datum._time_list = [] - + self.debug_datum._time_list = [ + [float(t) * 1e-6] for t in self.run_individual(10, 1, 1) + ] for i, node in enumerate(self.debug_datum.get_graph_nodes()): - start_time = datetime.now().time() - time_stamp = self._debug_run(i) - end_time = datetime.now().time() - self.debug_datum._time_list.append([time_stamp, start_time, end_time]) num_outputs = self.debug_datum.get_graph_node_output_num(node) for j in range(num_outputs): out_tensor = self._get_output_by_layer(i, j) @@ -212,8 +215,13 @@ def debug_get_output(self, node, out): ret = output_tensors[node] except: node_list = output_tensors.keys() - raise RuntimeError("Node " + node + " not found, available nodes are: " - + str(node_list) + ".") + raise RuntimeError( + "Node " + + node + + " not found, available nodes are: " + + str(node_list) + + "." + ) elif isinstance(node, int): output_tensors = self.debug_datum._output_tensor_list ret = output_tensors[node] @@ -242,7 +250,9 @@ def run(self, **input_dict): self.debug_datum.display_debug_result() def run_individual(self, number, repeat=1, min_repeat_ms=0): - self._run_individual(number, repeat, min_repeat_ms) + ret = self._run_individual(number, repeat, min_repeat_ms) + return ret.strip(",").split(",") if ret else [] + def exit(self): """Exits the dump folder and all its contents""" diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 560bf3da238e..2b26ae541b5f 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,7 +24,9 @@ #include #include #include + #include +#include #include "../graph_runtime.h" namespace tvm { @@ -39,40 +41,23 @@ namespace runtime { class GraphRuntimeDebug : public GraphRuntime { public: /*! - * \brief Run each operation and get the output. - * \param index The index of op which needs to be run. - * \return the elapsed time. - */ - double DebugRun(size_t index) { - CHECK(index < op_execs_.size()); - TVMContext ctx = data_entry_[entry_id(index, 0)]->ctx; - auto tbegin = std::chrono::high_resolution_clock::now(); - if (op_execs_[index]) { - op_execs_[index](); - } - TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); - auto tend = std::chrono::high_resolution_clock::now(); - double time = std::chrono::duration_cast >( - tend - tbegin).count(); - return time; - } - - /*! - * \brief Run each operation in the graph and print out the runtime per op. + * \brief Run each operation in the graph and get the time per op for all ops. * \param number The number of times to run this function for taking average. * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warmed up and will be discarded in case - there is lazy initialization. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warmed up and will be discarded in case + * there is lazy initialization. * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. + * By default, one `repeat` contains `number` runs. If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * \return Comma seperated string containing the elapsed time per op for the last + * iteration only, because returning a long string over rpc can be expensive. */ - void RunIndividual(int number, int repeat, int min_repeat_ms) { + std::string RunIndividual(int number, int repeat, int min_repeat_ms) { // warmup run GraphRuntime::Run(); - + std::ostringstream os; std::vector time_per_op(op_execs_.size(), 0); for (int i = 0; i < repeat; ++i) { std::chrono::time_point< @@ -96,7 +81,7 @@ class GraphRuntimeDebug : public GraphRuntime { auto op_tend = std::chrono::high_resolution_clock::now(); double op_duration = std::chrono::duration_cast< std::chrono::duration >(op_tend - op_tbegin).count(); - time_per_op[index] += op_duration * 1000; // ms + time_per_op[index] += op_duration * 1e6; // us } } } @@ -105,16 +90,20 @@ class GraphRuntimeDebug : public GraphRuntime { (tend - tbegin).count() * 1000; } while (duration_ms < min_repeat_ms); - LOG(INFO) << "Repeat: " << i; + LOG(INFO) << "Iteration: " << i; int op = 0; for (size_t index = 0; index < time_per_op.size(); index++) { if (op_execs_[index]) { time_per_op[index] /= number; LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " - << time_per_op[index] << " ms/iter"; + << time_per_op[index] << " us/iter"; } } } + for (size_t index = 0; index < time_per_op.size(); index++) { + os << time_per_op[index] << ","; + } + return os.str(); } /*! @@ -182,11 +171,7 @@ PackedFunc GraphRuntimeDebug::GetFunction( const std::string& name, const std::shared_ptr& sptr_to_self) { // return member functions during query. - if (name == "debug_run") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->DebugRun(static_cast(args[0].operator int64_t())); - }); - } else if (name == "get_output_by_layer") { + if (name == "get_output_by_layer") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutputByLayer(args[0], args[1]); }); @@ -206,7 +191,7 @@ PackedFunc GraphRuntimeDebug::GetFunction( CHECK_GT(number, 0); CHECK_GT(repeat, 0); CHECK_GE(min_repeat_ms, 0); - this->RunIndividual(number, repeat, min_repeat_ms); + *rv = this->RunIndividual(number, repeat, min_repeat_ms); }); } else { return GraphRuntime::GetFunction(name, sptr_to_self); diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index 3de270732403..717b23c22689 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -100,9 +100,6 @@ def check_verify(): out = mod.get_output(0, tvm.nd.empty((n,))) np.testing.assert_equal(out.asnumpy(), a + 1) - #test individual run - mod.run_individual(20, 2, 1) - mod.exit() #verify dump root delete after cleanup assert(not os.path.exists(directory)) @@ -129,7 +126,6 @@ def check_remote(): mod.run(x=tvm.nd.array(a, ctx)) out = tvm.nd.empty((n,), ctx=ctx) out = mod.get_output(0, out) - mod.run_individual(20, 2, 1) np.testing.assert_equal(out.asnumpy(), a + 1) check_verify() From 5ff99b66cbd695fddde03469c21c4948b91097db Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 23 May 2019 10:17:03 -0700 Subject: [PATCH 039/176] [NODE] Macro to define NodeRef methods, constructor style example (#3224) --- include/tvm/arithmetic.h | 44 ++++++++++++++-------- include/tvm/base.h | 61 ++++++++++++++++++------------- src/api/api_arith.cc | 17 ++++++--- src/arithmetic/const_int_bound.cc | 10 ++--- src/arithmetic/modular_set.cc | 9 +++-- 5 files changed, 86 insertions(+), 55 deletions(-) diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 9a8d9d372956..6eec767611e0 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -48,11 +48,7 @@ namespace arith { // Forward declare Analyzer class Analyzer; -/*! - * \brief reference class to ConstIntBoundNode - * \sa ConstIntBoundNode - */ -class ConstIntBound; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node { v->Visit("max_value", &max_value); } - TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value); - /*! \brief Number to represent +inf */ static const constexpr int64_t kPosInf = std::numeric_limits::max(); /*! @@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); }; -TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode); +/*! + * \brief reference class to ConstIntBoundNode + * \sa ConstIntBoundNode + */ +class ConstIntBound : public NodeRef { + public: + /*! + * \brief constructor by fields. + * \param min_value The mininum value. + * \param max_value The maximum value. + */ + TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value); + + static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; + TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); +}; /*! * \brief Analyzer to get constant integer bound over expression. @@ -133,11 +143,6 @@ class ConstIntBoundAnalyzer { Impl* impl_; }; -/*! - * \brief reference of ModularSetNode - * \sa ModularSetNode - */ -class ModularSet; /*! * \brief Range of a linear integer function. * Use to do specify the possible index values. @@ -162,13 +167,20 @@ class ModularSetNode : public Node { v->Visit("base", &base); } - TVM_DLL static ModularSet make(int64_t coeff, int64_t base); - static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); }; -TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode); +/*! + * \brief reference of ModularSetNode + * \sa ModularSetNode + */ +class ModularSet : public NodeRef { + public: + TVM_DLL ModularSet(int64_t coeff, int64_t base); + + TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); +}; /*! * \brief Analyzer to get modular information over expression. diff --git a/include/tvm/base.h b/include/tvm/base.h index ae2d91ff8523..049a427ffce8 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -39,21 +39,24 @@ using ::tvm::Node; using ::tvm::NodeRef; using ::tvm::AttrVisitor; -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; \ +/*! + * \brief Macro to define common node ref methods. + * \param TypeName The name of the NodeRef. + * \param BaseTypeName The Base type. + * \param NodeName The node container type. + */ +#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() const { return this->defined(); } \ + using ContainerType = NodeName; /*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. + * \brief Macro to define CopyOnWrite function in a NodeRef. + * \param NodeName The Type of the Node. * * CopyOnWrite will generate a unique copy of the internal node. * The node will be copied if it is referenced by multiple places. @@ -70,25 +73,33 @@ using ::tvm::AttrVisitor; * * \endcode */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - inline NodeName* CopyOnWrite() { \ +#define TVM_DEFINE_NODE_REF_COW(NodeName) \ + NodeName* CopyOnWrite() { \ CHECK(node_ != nullptr); \ if (!node_.unique()) { \ NodePtr n = make_node(*(operator->())); \ NodePtr(std::move(n)).swap(node_); \ } \ return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; + } +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ::tvm::NodeRef { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ + }; \ + +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_NODE_REF_COW(NodeName); \ + }; /*! * \brief save the node as well as all the node it depends on as json. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index fce73aabf6a7..55a706420f06 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound") TVM_REGISTER_API("arith.DomainTouched") .set_body_typed(DomainTouched); - TVM_REGISTER_API("_IntervalSetGetMin") .set_body_method(&IntSet::min); @@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing") TVM_REGISTER_API("_IntSetIsEverything") .set_body_method(&IntSet::is_everything); +ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { + return ConstIntBound(min_value, max_value); +} + TVM_REGISTER_API("arith._make_ConstIntBound") -.set_body_typed(ConstIntBoundNode::make); +.set_body_typed(MakeConstIntBound); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { + return ModularSet(coeff, base); +} TVM_REGISTER_API("arith._make_ModularSet") -.set_body_typed(ModularSetNode::make); +.set_body_typed(MakeModularSet); TVM_REGISTER_API("arith._CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index bfd06c8ba255..72b85084d59d 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -34,12 +34,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound ConstIntBoundNode::make( +ConstIntBound::ConstIntBound( int64_t min_value, int64_t max_value) { auto node = make_node(); node->min_value = min_value; node->max_value = max_value; - return ConstIntBound(node); + node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl : std::vector additional_info_; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. - static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; + static const constexpr int64_t kPosInf = ConstIntBound::kPosInf; static_assert(-kNegInf == kPosInf, "invariant of inf"); // internal helper functions /*! @@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl : ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ConstIntBoundNode::make(ret.min_value, ret.max_value); + return ConstIntBound(ret.min_value, ret.max_value); } void ConstIntBoundAnalyzer::Update(const Var& var, diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 7701e04844fa..57e82943b84c 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -35,11 +35,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ModularSetNode); -ModularSet ModularSetNode::make(int64_t coeff, int64_t base) { +ModularSet::ModularSet(int64_t coeff, int64_t base) { auto node = make_node(); node->coeff = coeff; node->base = base; - return ModularSet(node); + // finish construction. + node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl : * \return Bound that represent everything dtype can represent. */ static Entry Nothing() { - return Entry(0, 1); + return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ModularSetNode::make(ret.coeff, ret.base); + return ModularSet(ret.coeff, ret.base); } void ModularSetAnalyzer::Update(const Var& var, From 4427d34952fbd4aed8eda3286510d5d02d3c860c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 23 May 2019 10:23:12 -0700 Subject: [PATCH 040/176] Modified pick best to accumulate the best configurations from both the input and output file. (#3225) --- python/tvm/autotvm/record.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 4c0f98347d4b..14efb7bd9239 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -25,6 +25,8 @@ import pickle import json import time +import os +import itertools from collections import OrderedDict from .. import build, lower, target as _target @@ -238,6 +240,8 @@ def pick_best(in_file, out_file): """ Pick best entries from a file and store it to another file. This distill the useful log entries from a large log file. + If out_file already exists, the best entries from both + in_file and out_file will be saved. Parameters ---------- @@ -246,7 +250,12 @@ def pick_best(in_file, out_file): out_file: str or file The filename of output """ - best_context = ApplyHistoryBest(load_from_file(in_file)) + context = load_from_file(in_file) + if os.path.isfile(out_file): + out_context = load_from_file(out_file) + context = itertools.chain(context, out_context) + context, context_clone = itertools.tee(context) + best_context = ApplyHistoryBest(context) best_set = set() for v in best_context.best_by_model.values(): @@ -258,7 +267,7 @@ def pick_best(in_file, out_file): logger.info("Extract %d best records from the %s", len(best_set), in_file) fout = open(out_file, 'w') if isinstance(out_file, str) else out_file - for inp, res in load_from_file(in_file): + for inp, res in context_clone: if measure_str_key(inp) in best_set: fout.write(encode(inp, res) + "\n") best_set.remove(measure_str_key(inp)) From 489a805eb4144d12c40cd0564dbaa89ad22dd40f Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 23 May 2019 17:52:44 -0700 Subject: [PATCH 041/176] [LINT] handle more file types in ASF header (#3235) * Update add_asf_header.py * Update add_asf_header.py --- tests/lint/add_asf_header.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/lint/add_asf_header.py b/tests/lint/add_asf_header.py index 7e0352f4bc2c..1afb3a57b2f1 100644 --- a/tests/lint/add_asf_header.py +++ b/tests/lint/add_asf_header.py @@ -117,6 +117,7 @@ """.strip() FMT_MAP = { + "sh" : header_pystyle, "cc" : header_cstyle, "h" : header_cstyle, "py" : header_pystyle, @@ -128,6 +129,7 @@ "cmake" : header_pystyle, "rst" : header_rststyle, "gradle" : header_groovystyle, + "xml": header_mdstyle, } def add_header(fname, header): @@ -142,8 +144,23 @@ def add_header(fname, header): return with open(fname, "w") as outfile: - outfile.write(header + "\n\n") - outfile.write(orig) + skipline = False + lines = orig.split('\n') + ext = os.path.splitext(fname)[1][1:] + if ext == 'sh' and lines[0][:2] == '#!': + skipline = True + elif ext == 'xml' and lines[0][:2] == ' Date: Fri, 24 May 2019 09:29:14 -0700 Subject: [PATCH 042/176] [C++][API] Consistent RAII scoping API. (#3231) --- include/tvm/arithmetic.h | 25 +++-- include/tvm/base.h | 44 +++++++++ include/tvm/build_module.h | 131 +++++++++------------------ python/tvm/build_module.py | 2 +- python/tvm/target.py | 2 +- src/api/api_arith.cc | 4 +- src/arithmetic/analyzer.cc | 17 +++- src/arithmetic/rewrite_simplify.cc | 8 +- src/arithmetic/stmt_simplify.cc | 10 +- src/codegen/build_module.cc | 79 +++++++++------- src/codegen/codegen_aocl.cc | 6 +- src/codegen/codegen_vhls.cc | 6 +- src/codegen/llvm/codegen_llvm.cc | 6 +- src/codegen/spirv/codegen_spirv.cc | 6 +- src/relay/backend/build_module.cc | 10 +- src/relay/backend/compile_engine.cc | 8 +- src/relay/backend/vm/compiler.cc | 4 +- src/relay/pass/fold_constant.cc | 8 +- src/relay/pass/partial_eval.cc | 4 +- tests/cpp/build_module_test.cc | 12 +-- tests/cpp/relay_build_module_test.cc | 6 +- topi/src/topi.cc | 6 +- 22 files changed, 214 insertions(+), 190 deletions(-) diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 6eec767611e0..600e3c565358 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -290,14 +290,14 @@ class CanonicalSimplifier { }; /*! - * \brief A RAII constraint context. + * \brief Constraint context. * * \code * * Var("x"); * arith::Analyzer analyzer; * { - * arith::ConstraintContext cctx(&analyzer, x % 3 == 0); + * With scope(&analyzer, x % 3 == 0); * CHECK_EQ(analyzer.modular_set(x)->coeff, 3); * } * // constraint no longer in effect. @@ -306,19 +306,24 @@ class CanonicalSimplifier { * \endcode */ class ConstraintContext { - public: + private: + // declare friend to enable with. + friend class With; /*! * \brief Construct a constraint context. * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION; - /*! \brief destructor */ - ~ConstraintContext() DMLC_THROW_EXCEPTION { - exit_(); - } - - private: + ConstraintContext(Analyzer* analyzer, Expr constraint) + : analyzer_(analyzer), constraint_(constraint) {} + // enter the scope. + void EnterWithScope(); + // exit the scope. + void ExitWithScope(); + /*! \brief The analyzer */ + Analyzer* analyzer_; + /*! \brief The constraint */ + Expr constraint_; /*! \brief function to be called in recovery */ std::function exit_; }; diff --git a/include/tvm/base.h b/include/tvm/base.h index 049a427ffce8..f358f7f5d447 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -101,6 +101,50 @@ using ::tvm::AttrVisitor; TVM_DEFINE_NODE_REF_COW(NodeName); \ }; +/*! + * \brief RAII wrapper function to enter and exit a context object + * similar to python's with syntax. + * + * \code + * // context class + * class MyContext { + * private: + * friend class With; + MyContext(arguments); + * void EnterWithScope(); + * void ExitWithScope(); + * }; + * + * { + * With scope(arguments); + * // effect take place. + * } + * \endcode + * + * \tparam ContextType Type of the context object. + */ +template +class With { + public: + /*! + * \brief constructor. + * Enter the scope of the context. + */ + template + explicit With(Args&& ...args) + : ctx_(std::forward(args)...) { + ctx_.EnterWithScope(); + } + /*! \brief destructor, leaves the scope of the context. */ + ~With() DMLC_THROW_EXCEPTION { + ctx_.ExitWithScope(); + } + + private: + /*! \brief internal context type. */ + ContextType ctx_; +}; + /*! * \brief save the node as well as all the node it depends on as json. * This can be used to serialize any TVM object diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 7fb456c823a7..187a74552241 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -37,7 +37,7 @@ namespace tvm { /*! * \brief Container for target device information. -* Use target::llvm, target::cuda etc functions instead of constructing directly. +* Use target::llvm, target::cuda etc functions instead of constructing directly. */ class TargetNode : public Node { public: @@ -89,65 +89,47 @@ class TargetNode : public Node { mutable std::string str_repr_; }; +/*! \brief reference cpass to the target. */ class Target : public NodeRef { public: Target() {} explicit Target(NodePtr n) : NodeRef(n) {} - /*! * \brief Create a Target given a string * \param target_str the string to parse */ - TVM_DLL static Target create(const std::string& target_str); - - /*! - * \brief Push a new target context onto the thread local stack. The Target on top of - * the stack is used to determine which specialization to use when invoking a GenericFunc. - * \param target The target to set as the current context. - */ - TVM_DLL static void EnterTargetScope(const tvm::Target& target); - - /*! - * \brief Pop a target off the thread local context stack, restoring the previous target - * as the current context. - */ - TVM_DLL static void ExitTargetScope(); - + TVM_DLL static Target Create(const std::string& target_str); /*! - * \brief Get the current target context from thread local storage. - * \param allow_not_defined If the context stack is empty and this is set to true, an - * undefined Target will be returned. Otherwise, an empty context stack will cause a - * runtime error. - * \return The target that is the current context. The target may not be defined if - * allow_not_defined is true. - */ - TVM_DLL static tvm::Target current_target(bool allow_not_defined = true); + * \brief Get the current target context from thread local storage. + * \param allow_not_defined If the context stack is empty and this is set to true, an + * undefined Target will be returned. Otherwise, an empty context stack will cause a + * runtime error. + * \return The target that is the current context. The target may not be defined if + * allow_not_defined is true. + */ + TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - inline const TargetNode* operator->() const { + const TargetNode* operator->() const { return static_cast(node_.get()); } using ContainerType = TargetNode; -}; - -/*! - * \brief RAII container to provide a scoped target context. Pushes a target onto the - * context stack when constructed, and pops it when destructed. - */ -struct TargetContext { + class Internal; + private: + // enable with syntax. + friend class Internal; + friend class With; /*! - * \brief Enter a new target context. The given target becomes the new current context. - * When the TargetContext is destructed, the previous context is restored. - * \param target The target to set as the new current context. + * \brief Push a new target context onto the thread local stack. + * The Target on top of the stack is used to determine which + * specialization to use when invoking a GenericFunc. */ - explicit TargetContext(const tvm::Target& target) { - Target::EnterTargetScope(target); - } - - /*! \brief Destructor. Pops the context off the thread local stack. */ - ~TargetContext() { - Target::ExitTargetScope(); - } + TVM_DLL void EnterWithScope(); + /*! + * \brief Pop a target off the thread local context stack, + * restoring the previous target as the current context. + */ + TVM_DLL void ExitWithScope(); }; /*! \brief This namespace provides functions to construct Target instances */ @@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector& options = } // namespace target -class BuildConfig; - /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class BuildConfigNode : public Node { public: /*! @@ -271,69 +251,48 @@ class BuildConfigNode : public Node { }; /*! -* \brief Container for build configuration options -*/ + * \brief Build configuration for compilations. + */ class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} - const BuildConfigNode* operator->() const { return static_cast(node_.get()); } - BuildConfigNode* operator->() { return static_cast(node_.get()); } - /*! - * \brief Push a new BuildConfig context onto the thread local stack. - * \param build_config The configuration to set as the current context. + * \brief Construct a BuildConfig containing a empty build config node. + * \return The new BuildConfig */ - TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config); - - /*! - * \brief Pop a build config off the thread local context stack, restoring the previous - * configuration as the current context. - */ - TVM_DLL static void ExitBuildConfigScope(); - + TVM_DLL static BuildConfig Create(); /*! * \brief Get the current BuildConfig context from thread local storage, or a default * configuration if a BuildConfig scope has not been entered. * \return The configuration that is the current context. */ - TVM_DLL static tvm::BuildConfig Current(); + TVM_DLL static BuildConfig Current(); using ContainerType = BuildConfigNode; -}; + class Internal; -/*! - * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the - * context stack when constructed, and pops it when destructed. - */ -struct BuildConfigContext { + private: + // Enable with syntax. + friend class With; /*! - * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current - * context. When the BuildConfigContext is destructed, the previous context is restored. - * \param build_config The BuildConfig to set as the new current context. + * \brief Push a new BuildConfig context onto the thread local stack. */ - explicit BuildConfigContext(const tvm::BuildConfig& build_config) { - BuildConfig::EnterBuildConfigScope(build_config); - } + TVM_DLL void EnterWithScope(); - /*! \brief Destructor. Pops the context off the thread local stack. */ - ~BuildConfigContext() { - BuildConfig::ExitBuildConfigScope(); - } + /*! + * \brief Pop a build config off the thread local context stack, + * restoring the previous configuration as the current context. + */ + TVM_DLL void ExitWithScope(); }; -/*! -* \brief Construct a BuildConfig containing a new BuildConfigNode -* \return The new BuildConfig -*/ -TVM_DLL BuildConfig build_config(); - /*! * \brief Build a LoweredFunc given a schedule, args and binds * \param sch The schedule to lower. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index a28ab98fb60e..76170a844db1 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -187,7 +187,7 @@ def __enter__(self): def __exit__(self, ptype, value, trace): if self.dump_pass_ir: BuildConfig._dump_ir.exit() - _api_internal._ExitBuildConfigScope() + _api_internal._ExitBuildConfigScope(self) def __setattr__(self, name, value): if name in BuildConfig._node_defaults: diff --git a/python/tvm/target.py b/python/tvm/target.py index eff0088b37ce..828fff8e228c 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -133,7 +133,7 @@ def __enter__(self): return self def __exit__(self, ptype, value, trace): - _api_internal._ExitTargetScope() + _api_internal._ExitTargetScope(self) @register_node diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 55a706420f06..4d5d8bdf58d3 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { // can't use make_shared due to noexcept(false) decl in destructor, // see https://stackoverflow.com/a/43907314 - auto ctx = - std::shared_ptr(new ConstraintContext(self.get(), args[0])); + auto ctx = std::shared_ptr >( + new With(self.get(), args[0])); auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 420d6f9c1d0d..bd8c7005f458 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) { // skip rewrite simplify } -ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { + +void ConstraintContext::EnterWithScope() { + CHECK(exit_ == nullptr); // entering the scope. - auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); - auto f1 = analyzer->modular_set.EnterConstraint(constraint); + auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); + auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); // recovery function. exit_ = [f0, f1]() { if (f1 != nullptr) f1(); @@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) }; } +void ConstraintContext::ExitWithScope() { + CHECK(exit_ != nullptr); + exit_(); +} + bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { return ptr->value > lower_bound; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 58d2b83a223a..0de2a2535ae7 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) { Expr cond = Mutate(op->condition); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->true_value); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->false_value); } if (is_zero(cond)) { @@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) { Expr cond = Mutate(op->args[0]); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->args[1]); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->args[2]); } if (is_zero(cond)) { diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index c793214b92f4..403187eb39fd 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator { Expr condition = this->Mutate(op->condition); Stmt then_case, else_case; { - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); then_case = this->Mutate(op->then_case); } if (op->else_case.defined()) { - ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition))); + With ctx(&analyzer_, Mutate(Not::make(condition))); else_case = this->Mutate(op->else_case); } if (is_one(condition)) return then_case; @@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator { Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { Expr condition = this->Mutate(op->condition); Expr message = this->Mutate(op->message); - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); Stmt body = this->Mutate(op->body); if (condition.same_as(op->condition) && diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index ac6b797d9683..834b4eea7e3f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Compile executable modules. * \file build_module.cc */ @@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate") TVM_REGISTER_API("_TargetFromString") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; - - *ret = Target::create(target_str); + *ret = Target::Create(target_str); }); std::vector TargetNode::keys() const { @@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) { return ""; } -Target Target::create(const std::string& target_str) { +Target Target::Create(const std::string& target_str) { if (target_str.length() == 0) { LOG(ERROR) << "target_str must not be empty"; } @@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) { struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ std::stack context_stack; - - TVMTargetThreadLocalEntry() { - } }; /*! \brief Thread local store to hold the Target context stack. */ typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; -void Target::EnterTargetScope(const tvm::Target& target) { +void Target::EnterWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - entry->context_stack.push(target); + entry->context_stack.push(*this); } -void Target::ExitTargetScope() { +void Target::ExitWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } -tvm::Target Target::current_target(bool allow_not_defined) { +tvm::Target Target::Current(bool allow_not_defined) { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); @@ -574,7 +571,7 @@ runtime::Module build(const Map>& inputs, const BuildConfig& config) { Map> updated_input; for (const auto& it : inputs) { - auto target = Target::create(it.first); + auto target = Target::Create(it.first); updated_input.Set(target, it.second); } return build(updated_input, target_host, config); @@ -589,33 +586,35 @@ runtime::Module build(const Array& funcs, return build(inputs, target_host, config); } -BuildConfig build_config() { +BuildConfig BuildConfig::Create() { return BuildConfig(make_node()); } /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMBuildConfigThreadLocalEntry { /*! \brief The default build config if the stack is empty */ - tvm::BuildConfig default_config; + BuildConfig default_config; /*! \brief The current build config context */ - std::stack context_stack; + std::stack context_stack; TVMBuildConfigThreadLocalEntry() : - default_config(build_config()) { + default_config(BuildConfig::Create()) { } }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; -void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) { +void BuildConfig::EnterWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - entry->context_stack.push(build_config); + entry->context_stack.push(*this); } -void BuildConfig::ExitBuildConfigScope() { +void BuildConfig::ExitWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } @@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector& tags, void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { auto node = static_cast(node_.get()); - auto target = Target::current_target(true); + auto target = Target::Current(true); PackedFunc func; if (target.defined()) { @@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig") *ret = BuildConfig::Current(); }); +class BuildConfig::Internal { + public: + static void EnterScope(BuildConfig target) { + target.EnterWithScope(); + } + static void ExitScope(BuildConfig target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig target = args[0]; - BuildConfig::EnterBuildConfigScope(target); - }); +.set_body_typed(BuildConfig::Internal::EnterScope); TVM_REGISTER_API("_ExitBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig::ExitBuildConfigScope(); - }); +.set_body_typed(BuildConfig::Internal::ExitScope); TVM_REGISTER_API("_BuildConfigSetAddLowerPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc") TVM_REGISTER_API("_GetCurrentTarget") .set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; - *ret = Target::current_target(allow_not_defined); + *ret = Target::Current(allow_not_defined); }); +class Target::Internal { + public: + static void EnterScope(Target target) { + target.EnterWithScope(); + } + static void ExitScope(Target target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target target = args[0]; - Target::EnterTargetScope(target); - }); +.set_body_typed(Target::Internal::EnterScope); TVM_REGISTER_API("_ExitTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target::ExitTargetScope(); - }); +.set_body_typed(Target::Internal::ExitScope); } // namespace tvm diff --git a/src/codegen/codegen_aocl.cc b/src/codegen/codegen_aocl.cc index 6f899cbb0b53..03b9b6869d17 100644 --- a/src/codegen/codegen_aocl.cc +++ b/src/codegen/codegen_aocl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array funcs, std::string target_str, std::string cmd = "aoc aocl.cl"; // AOCL supports fp64. cmd += " -Dcl_khr_fp64"; - Target target = Target::create(target_str); + Target target = Target::Create(target_str); if (target->device_name != "") { cmd += " -board=" + target->device_name; } diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index a18312fe6af5..4d86cc5b4b00 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { std::string xclbin; if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { - Target target = Target::create(target_str); + Target target = Target::Create(target_str); xclbin = (*f)(kernel_info, target->device_name).operator std::string(); } else { LOG(FATAL) << "Cannot compile Vivado HLS code."; diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index bedcdc79ff1f..1e56583a37fd 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index e6fc0088dc81..fd113ca4614a 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8a0c32fc6684..3b1491072d25 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (targets.size() == 1) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); for (const auto& kv : targets) { - TargetContext tctx(kv.second); + With tctx(kv.second); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); } } else { @@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode { */ Target CreateDefaultTarget(int device_type) { std::string name = runtime::DeviceName(device_type); - if (name == "cpu") return Target::create("llvm"); - if (name == "gpu") return Target::create("cuda"); - return Target::create(name); + if (name == "cpu") return Target::Create("llvm"); + if (name == "gpu") return Target::Create("cuda"); + return Target::Create(name); } /*! * \brief Update the target and fallback device required for heterogeneous @@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode { const RelayBuildConfig& cfg, const std::unordered_map ¶ms) { // convert - tvm_cfg_ = build_config(); + tvm_cfg_ = BuildConfig::Create(); TargetsMap device_target; if (targets_.size() > 1) { device_target = UpdateHeterogeneousInputs(targets_, cfg); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index a824c457107a..f11dd2875b80 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_[key] = value; } // Enforce use the target. - TargetContext target_ctx(key->target); + With target_scope(key->target); CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); @@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)( spair.first, all_args, cache_node->func_name, key->source_func); } else { - tvm::BuildConfig bcfg = tvm::build_config(); + tvm::BuildConfig bcfg = BuildConfig::Create(); std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 97f03c629cb7..602e92759624 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor { // Next generate the invoke instruction. CHECK(func->IsPrimitive()); - auto target = Target::create("llvm"); + auto target = Target::Create("llvm"); auto key = CCacheKeyNode::make(func, target); auto cfunc = engine->Lower(key); // TODO(jroesch): support lowered funcs for multiple targets @@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, runtime::Module mod; if (lowered_funcs.size() > 0) { // TODO(@jroesch): we need to read target from build config - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); if (const auto* f = runtime::Registry::Get("relay.backend.build")) { mod = (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target); } else { diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 45aa449e72ab..c085d80d06e2 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return ConstantFolder(CreateInterpreter( Module(nullptr), ctx, target)).Mutate(expr); diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 5349532ca697..ad861743dfd5 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -375,10 +375,10 @@ DLContext CPUContext() { } FInterpreter CPUInterpreter() { - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return CreateInterpreter(Module(nullptr), CPUContext(), target); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 393714d8f636..6dbd78e9566d 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,14 +50,14 @@ TEST(BuildModule, Basic) { auto args = Array({ A, B, C }); std::unordered_map binds; - auto config = build_config(); + auto config = BuildConfig::Create(); auto target = target::llvm(); auto lowered = lower(s, args, "func", binds, config); auto module = build(lowered, target, Target(), config); - auto mali_target = Target::create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); - CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); + auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); + CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); } TEST(BuildModule, Heterogeneous) { @@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) { auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); auto s2 = create_schedule({elemwise_sub->op}); - auto config = build_config(); + auto config = BuildConfig::Create(); auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index a1ab29959127..3f46eed9f10e 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -75,7 +75,7 @@ TEST(Relay, BuildModule) { auto json_f = build_mod.GetFunction("get_graph_json", false); auto mod_f = build_mod.GetFunction("get_module", false); Map targets; - Target llvm_tgt = Target::create("llvm"); + Target llvm_tgt = Target::Create("llvm"); targets.Set(0, llvm_tgt); build_f(func, targets, llvm_tgt); std::string json = json_f(); diff --git a/topi/src/topi.cc b/topi/src/topi.cc index d3e0bc938f7c..57a2743ae6d0 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) { TVM_REGISTER_GLOBAL("topi.TEST_create_target") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tvm::Target::create(args[0]); + *rv = tvm::Target::Create(args[0]); }); /* Ops from broadcast.h */ @@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function< */ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - auto target = Target::current_target(false); + auto target = Target::Current(false); Array outs; NodeRef argNodeRef = args[0]; if (argNodeRef->type_index() == outs->type_index()) { @@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function Date: Fri, 24 May 2019 12:05:00 -0700 Subject: [PATCH 043/176] [Relay][Transform] merge PassContext and BuildConfig (#3234) --- docs/api/python/relay/build_module.rst | 8 - docs/api/python/relay/transform.rst | 45 +++ include/tvm/relay/transform.h | 92 ++++- python/tvm/relay/__init__.py | 3 +- python/tvm/relay/build_module.py | 98 +---- python/tvm/relay/quantize/quantize.py | 14 +- python/tvm/relay/transform.py | 125 +++++-- src/relay/pass/pass_manager.cc | 372 +++++++++++++------ tests/python/frontend/coreml/test_forward.py | 4 +- tests/python/frontend/keras/test_forward.py | 2 +- tutorials/frontend/from_tflite.py | 2 +- 11 files changed, 501 insertions(+), 264 deletions(-) create mode 100644 docs/api/python/relay/transform.rst diff --git a/docs/api/python/relay/build_module.rst b/docs/api/python/relay/build_module.rst index 28dadea21e78..26164bf1ade9 100644 --- a/docs/api/python/relay/build_module.rst +++ b/docs/api/python/relay/build_module.rst @@ -22,17 +22,9 @@ tvm.relay.build_module .. autofunction:: tvm.relay.build_module.build -.. autofunction:: tvm.relay.build_module.build_config - .. autofunction:: tvm.relay.build_module.optimize .. autofunction:: tvm.relay.build_module.create_executor -.. autoclass:: tvm.relay.build_module.BuildConfig - :members: - -.. autofunction:: tvm.relay.build_module.build_config - :members: - .. autoclass:: tvm.relay.build_module.GraphExecutor :members: diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst new file mode 100644 index 000000000000..4eb7f9d8fea7 --- /dev/null +++ b/docs/api/python/relay/transform.rst @@ -0,0 +1,45 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.transform +---------------------- + +.. automodule:: tvm.relay.transform + +.. autofunction:: tvm.relay.transform.build_config + +.. autofunction:: tvm.relay.transform.module_pass + +.. autofunction:: tvm.relay.transform.function_pass + +.. autoclass:: tvm.relay.transform.Pass + :members: + +.. autoclass:: tvm.relay.transform.PassInfo + :members: + +.. autoclass:: tvm.relay.transform.PassContext + :members: + +.. autoclass:: tvm.relay.transform.ModulePass + :members: + +.. autoclass:: tvm.relay.transform.FunctionPass + :members: + +.. autoclass:: tvm.relay.transform.Sequential + :members: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ba25483dfbb2..5123f3a3dcf3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -56,11 +56,13 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ +#include #include #include #include #include #include +#include #include namespace tvm { @@ -83,18 +85,69 @@ class PassContextNode : public RelayNode { */ ErrorReporter err_reporter; + /*! \brief The default optimization level. */ + int opt_level{2}; + + /*! \brief CPU is the default fallback device for heterogeneous execution. */ + int fallback_device{static_cast(kDLCPU)}; + + /*! \brief The list of required passes. */ + tvm::Array required_pass; + /*! \brief The list of disabled passes. */ + tvm::Array disabled_pass; + PassContextNode() = default; void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("fallback_device", &fallback_device); + v->Visit("required_pass", &required_pass); + v->Visit("disabled_pass", &disabled_pass); } - TVM_DLL static PassContext make(); - static constexpr const char* _type_key = "relay.PassContext"; TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); }; -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) +class PassContext : public NodeRef { + public: + PassContext() {} + explicit PassContext(tvm::NodePtr n) : NodeRef(n) {} + + /* + * \brief Constructor of a `PassContext` object. + * + * \param opt_level The optimization level that will be applied. + * \param fallback_device The fallback device used for heterogeneous + * execution. + * \param required_pass The passes that are required for a context to execute + * other passes. + * \param required_pass The passes that will be disabled during the + * optimization under a context. + */ + TVM_DLL PassContext(int opt_level, + int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass); + + // Get the currently used pass context. + TVM_DLL static PassContext Current(); + + const PassContextNode* operator->() const; + + using ContainerType = PassContextNode; + class Internal; + + private: + // The entry of a pass context scope. + TVM_DLL void EnterWithScope(); + // The exit of a pass context scope. + TVM_DLL void ExitWithScope(); + + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class tvm::With; +}; /* * \brief The meta data of a pass. @@ -150,20 +203,28 @@ class PassNode : public RelayNode { virtual PassInfo Info() const = 0; /*! - * \brief Set the context information for a pass. + * \brief Execute the optimization pass using a functor. This functor + * internally uses a current pass context. + * + * \param mod The module that an optimization pass runs on. * - * \param pass_ctx The context information for a certain pass. + * \return The updated module. */ - virtual void SetContext(const PassContext& pass_ctx) = 0; + Module operator()(const Module& mod) const { + return this->operator()(mod, PassContext::Current()); + } /*! - * \brief Execute the optimization pass using a functor. + * \brief Execute the optimization pass using a functor under a given pass context. * * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that will be used to help the execution of + * optimizations. * * \return The updated module. */ - virtual Module operator()(const Module& mod) const = 0; + virtual Module operator()(const Module& mod, + const PassContext& pass_ctx) const = 0; void VisitAttrs(tvm::AttrVisitor* v) override {} @@ -189,13 +250,22 @@ class Sequential : public Pass { public: /*! * \brief The constructor of `Sequential`. + * * \param passes The passes to apply. * \param pass_info The pass metadata. - * \param disabled The passes that will not be applied. */ TVM_DLL Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled); + PassInfo pass_info); +/*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param name The name of a sequential pass. It's defaulted to "sequential". + * This allows users to only provide a list of passes and execute them + * under a given context. + */ + TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + Sequential() = default; explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index d832c8988795..1c8f5d6ceed3 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -26,7 +26,8 @@ from . import adt from . import ir_pass from . import transform -from .build_module import build, build_config, create_executor +from .build_module import build, create_executor +from .transform import build_config from . import prelude from . import parser from . import debug diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d0ad78fee67f..6cee393d5f91 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -28,81 +28,10 @@ from . import ir_pass from . import ty as _ty from . import expr as _expr +from . import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor -class BuildConfig(object): - """Configuration scope to set a build config option. - - Parameters - ---------- - kwargs - Keyword arguments of configurations to set. - """ - current = None - defaults = { - "opt_level": 2, - "add_pass": None, - "disable_pass": None, - "fallback_device": None, - } - - def __init__(self, **kwargs): - self._old_scope = None - for k, _ in kwargs.items(): - if k not in BuildConfig.defaults: - raise ValueError("invalid argument %s, candidates are %s" % - (k, BuildConfig.defaults.keys())) - self._attr = kwargs - - def __getattr__(self, name): - if name not in self._attr: - return BuildConfig.defaults[name] - return self._attr[name] - - def __enter__(self): - # pylint: disable=protected-access - self._old_scope = BuildConfig.current - attr = BuildConfig.current._attr.copy() - attr.update(self._attr) - self._attr = attr - BuildConfig.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_scope - BuildConfig.current = self._old_scope - - -BuildConfig.current = BuildConfig() - - -def build_config(**kwargs): - """Configure the build behavior by setting config variables. - - Parameters - ---------- - opt_level: int, default=2 - Optimization level. See OPT_PASS_LEVEL for level of each pass. - - add_pass: set of str - Optimization pass to be added regardless of optimization level. - - disable_pass: set of str - Optimization pass to be disabled during optimization. - - fallback_device : str or tvm.TVMContext - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - - Returns - ------- - config: BuildConfig - The build configuration - """ - return BuildConfig(**kwargs) - - def _update_target(target): target = target if target else _target.current_target() if target is None: @@ -189,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params def _setup_build_config(self, params): - cfg = BuildConfig.current + cfg = _transform.PassContext.current() # Set opt_level. self.set_opt_level(cfg.opt_level) @@ -199,24 +128,24 @@ def _setup_build_config(self, params): self.set_fallback_device(cfg.fallback_device) # Add required passes. - if cfg.add_pass: + if cfg.required_pass: passes = set() - if isinstance(cfg.add_pass, (list, tuple, set)): - passes = set(cfg.add_pass) + if isinstance(cfg.required_pass, (list, tuple, set)): + passes = set(cfg.required_pass) else: raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.add_pass))) + "got {}".format(type(cfg.required_pass))) for pass_name in passes: self.add_pass(pass_name) # Add disabled passes. - if cfg.disable_pass: + if cfg.disabled_pass: passes = set() - if isinstance(cfg.disable_pass, (list, tuple, set)): - passes = set(cfg.disable_pass) + if isinstance(cfg.disabled_pass, (list, tuple, set)): + passes = set(cfg.disabled_pass) else: raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disable_pass))) + "but got {}".format(type(cfg.disabled_pass))) for pass_name in passes: self.disable_pass(pass_name) @@ -287,12 +216,11 @@ def set_fallback_device(self, fallback_device): fallback_device : str or tvm.TVMContext The fallback device used for heterogeneous execution. """ - if isinstance(fallback_device, str): + if isinstance(fallback_device, (int, str)): fallback_device = _nd.context(fallback_device) if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str " + - "TVMContext, or dict of device name to target, " + - "but received: {}".format(type(fallback_device))) + raise TypeError("fallback_device is expected to be str, int, or " + + "TVMContext but received: {}".format(type(fallback_device))) self._set_fallback_device(fallback_device.device_type) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7fd0099e64a2..2423e76d308a 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from . import _quantize from .. import expr as _expr from .. import ir_pass as _ir_pass -from .. import build_module as _build +from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -301,7 +301,7 @@ def optimize(func, params=None): "FoldConstant", "CanonicalizeOps"] - cfg = _build.build_config(add_pass=opt_passes) + cfg = _transform.build_config(required_pass=opt_passes) if params: name_dict = {} @@ -321,25 +321,25 @@ def optimize(func, params=None): bind_dict[arg] = _expr.const(v) func = _expr.bind(func, bind_dict) - if "SimplifyInference" in cfg.add_pass: + if "SimplifyInference" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.simplify_inference(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) - if "FoldScaleAxis" in cfg.add_pass: + if "FoldScaleAxis" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.backward_fold_scale_axis(func) func = _ir_pass.infer_type(func) func = _ir_pass.forward_fold_scale_axis(func) func = _ir_pass.fold_constant(func) - if "CanonicalizeOps" in cfg.add_pass: + if "CanonicalizeOps" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.canonicalize_ops(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) return func diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 877538afea34..a7887c630c76 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -23,8 +23,10 @@ """ import types +from tvm._ffi.runtime_ctypes import TVMContext from . import _transform from .base import RelayNode, register_relay_node +from .. import nd as _nd @register_relay_node @@ -57,10 +59,102 @@ class PassContext(RelayNode): Each pass context contains a number of auxiliary information that is used to help an optimization pass. Such information includes the error reporter to record the errors of during the optimization, etc. + + opt_level : Optional[int] + The optimization level of this pass. + + fallback_device : Optional[Union[int, str, TVMContext]] + The fallback device type. It is also used as the default device for + operators that are not annotated during heterogeneous execution. + + required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are required by a certain pass. + + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are disabled. """ + def __init__(self, + opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device).device_type + elif isinstance(fallback_device, TVMContext): + fallback_device = fallback_device.device_type + if not isinstance(fallback_device, int): + raise TypeError("required_pass is expected to be the type of " + + "int/str/TVMContext.") + + required = list(required_pass) if required_pass else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required_pass is expected to be the type of " + + "list/tuple/set.") - def __init__(self): - self.__init_handle_by_constructor__(_transform.PassContext) + disabled = list(disabled_pass) if disabled_pass else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + + "list/tuple/set.") + + self.__init_handle_by_constructor__(_transform.PassContext, opt_level, + fallback_device, required, + disabled) + + def __enter__(self): + _transform.EnterPassContext(self) + return self + + def __exit__(self, ptype, value, trace): + _transform.ExitPassContext(self) + + @staticmethod + def current(): + """Return the current pass context.""" + return _transform.GetCurrentPassContext() + + +def build_config(opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + """Configure the build behavior by setting config variables. + + Parameters + ---------- + opt_level: int, optional + Optimization level. The optimization pass name and level are as the + following: + + .. code-block:: python + + OPT_PASS_LEVEL = { + "SimplifyInference": 0, + "OpFusion": 1, + "FoldConstant": 2, + "CombineParallelConv2D": 3, + "FoldScaleAxis": 3, + "AlterOpLayout": 3, + "CanonicalizeOps": 3, + "EliminateCommonSubexpr": 3, + } + + fallback_device : int, str, or tvm.TVMContext, optional + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. + + required_pass: set of str, optional + Optimization passes that are required regardless of optimization level. + + disabled_pass: set of str, optional + Optimization passes to be disabled during optimization. + + Returns + ------- + pass_context: PassContext + The pass context for optimizations. + """ + return PassContext(opt_level, fallback_device, required_pass, + disabled_pass) @register_relay_node @@ -70,20 +164,6 @@ class Pass(RelayNode): conveniently interact with the base class. """ - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _transform.SetContext(self, pass_ctx) - @property def info(self): """Get the pass meta.""" @@ -150,32 +230,23 @@ class Sequential(Pass): required : Optional[List[str]] The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. """ def __init__(self, passes=None, opt_level=2, name="sequential", - required=None, - disabled=None): + required=None): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - required = required if required else [] if not isinstance(required, (list, tuple)): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__(_transform.Sequential, - passes, opt_level, name, required, - disabled) + passes, opt_level, name, required) def module_pass(pass_func=None, opt_level=None, name=None, required=None): diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a105b692aa9d..4bcc0bb39cc4 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -22,8 +22,14 @@ * \file src/relay/pass/pass_manager.cc * \brief Relay pass manager implementation. */ +#include #include #include +#include + +#include +#include +#include namespace tvm { namespace relay { @@ -31,6 +37,98 @@ namespace transform { using tvm::IRPrinter; +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + */ +class OptPassLevel { + public: + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ + int operator[](const std::string& key) const { + const auto data = CreateMap(); + auto it = data.find(key); + if (it == data.end()) { + return -1; + } + return it->second; + } + + private: + static const std::unordered_map CreateMap() { + const std::unordered_map m = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} + }; + return m; + } +}; + +PassContext::PassContext(int opt_level, int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass) { + auto ctx = make_node(); + ctx->opt_level = opt_level; + ctx->fallback_device = fallback_device; + ctx->required_pass = std::move(required_pass); + ctx->disabled_pass = std::move(disabled_pass); + node_ = std::move(ctx); +} + +const PassContextNode* PassContext::operator->() const { + return static_cast(node_.get()); +} + +struct RelayPassContextThreadLocalEntry { + /*! \brief The default pass context. */ + PassContext default_context; + + /*! \brief The current pass context. */ + std::stack context_stack; + + RelayPassContextThreadLocalEntry() { + default_context = PassContext(make_node()); + } +}; + +/*! \brief Thread local store to hold the pass context. */ +typedef dmlc::ThreadLocalStore + RelayPassContextThreadLocalStore; + +void PassContext::EnterWithScope() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + entry->context_stack.push(*this); +} + +void PassContext::ExitWithScope() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); + entry->context_stack.pop(); +} + +PassContext PassContext::Current() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + if (!entry->context_stack.empty()) { + return entry->context_stack.top(); + } else { + return entry->default_context; + } +} + class ModulePass; /*! @@ -58,38 +156,26 @@ class ModulePassNode : public PassNode { } /*! - * \brief Run a module pass on a certain module. + * \brief Run a module pass on given pass context. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. */ PassInfo Info() const { return pass_info; } - /*! - * \brief Set the context information for a module pass. - * - * \param pass_ctx The context information for a module pass. - */ - void SetContext(const PassContext& pass_ctx) final; - TVM_DLL static ModulePass make( runtime::TypedPackedFunc pass_func, PassInfo pass_info); static constexpr const char* _type_key = "relay.ModulePass"; TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); @@ -124,26 +210,20 @@ class FunctionPassNode : public PassNode { } /*! - * \brief Run a function pass on a certain module. + * \brief Run a function pass on given pass context. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. */ PassInfo Info() const { return pass_info; } - /*! - * \brief Set the context information for a function-level pass. - * - * \param pass_ctx The context information for a function-level pass. - */ - void SetContext(const PassContext& pass_ctx) final; - TVM_DLL static FunctionPass make( runtime::TypedPackedFunc pass_func, PassInfo pass_info); @@ -160,11 +240,6 @@ class FunctionPassNode : public PassNode { * \return Return true if the function will be skipped, otherwise false. */ bool SkipFunction(const Function& func) const; - - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); @@ -182,18 +257,17 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; - /*! \brief A list of passes that used to compose a sequential pass. */ - tvm::Array passes; /*! - * \brief A list of disabled passes that should be excluded when executing the - * sequential pass. + * \brief A helper struct to get the optimization pass name to opt level + * mapping. */ - tvm::Array disabled; + OptPassLevel opt_pass_level; + /*! \brief A list of passes that used to compose a sequential pass. */ + tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); - v->Visit("disabled", &disabled); } /*! @@ -210,6 +284,15 @@ class SequentialNode : public PassNode { passes.push_back(pass); } + /*! + * \brief Check if a pass is enabled. + * + * \param pass_name The name of an optimization/analysis pass. + * + * \return true if the pass is enabled. Otherwise, false. + */ + bool pass_enabled(const std::string& pass_name) const; + /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -224,7 +307,11 @@ class SequentialNode : public PassNode { */ void ResolveDependency(const Module& mod); - TVM_DLL std::vector DisabledPasses() const; + std::unordered_set DisabledPasses( + const Array& disabled) const; + + std::unordered_set RequiredPasses( + const Array& disabled) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -232,27 +319,15 @@ class SequentialNode : public PassNode { * be overloaded to focus on different metrics, i.e. performance, * memory footprint, etc. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that these passes are applied on. + * \param pass_ctx The context that these passes execute on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; - - /*! - * \brief Set the context information for a sequential pass. - * - * \param pass_ctx The context information for a sequential pass. - */ - void SetContext(const PassContext& pass_ctx) final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; PassInfo PassInfoNode::make(int opt_level, std::string name, @@ -264,11 +339,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name, return PassInfo(pass_info); } -PassContext PassContextNode::make() { - auto ctx = make_node(); - return PassContext(ctx); -} - ModulePass ModulePassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { @@ -279,23 +349,19 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) 1. Check and handle the required passes. -// 2. Probably use CoW for all places that use module instead of -// returning the updated one. -Module ModulePassNode::operator()(const Module& mod) const { +// TODO(zhiics) Check and handle the required passes. +Module ModulePassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; + CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx_); + auto updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } -void ModulePassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; -} - FunctionPass FunctionPassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { @@ -307,31 +373,22 @@ FunctionPass FunctionPassNode::make( // Perform Module -> Module optimizations at the Function level. // TODO(zhiics) Check and handle the required passes. -Module FunctionPassNode::operator()(const Module& mod) const { +Module FunctionPassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; CHECK(mod.defined()); - std::vector> updated_funcs; - ModuleNode* mod_node = mod.operator->(); - for (const auto& it : mod_node->functions) { - if (!SkipFunction(it.second)) { - auto updated_func = pass_func(it.second, pass_ctx_); - CHECK(updated_func.defined()); - updated_funcs.push_back({std::move(it.first), std::move(updated_func)}); - } - } + Module new_mod = ModuleNode::make({}, mod->type_definitions); - // Update the optimized functions. - for (const auto& it : updated_funcs) { - mod_node->Update(it.first, it.second); + // Execute the pass function and return a new module. + for (const auto& it : mod->functions) { + auto updated_func = + SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx); + new_mod->Add(it.first, updated_func); } - return GetRef(mod_node); -} - -void FunctionPassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; + return new_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode @@ -342,31 +399,23 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { return pval && pval->value != 0; } -Sequential::Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled) { +Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { auto n = make_node(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); - n->disabled = std::move(disabled); node_ = std::move(n); } -const SequentialNode* Sequential::operator->() const { - return static_cast(this->node_.get()); +Sequential::Sequential(tvm::Array passes, std::string name) { + auto n = make_node(); + n->passes = std::move(passes); + PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); + n->pass_info = std::move(pass_info); + node_ = std::move(n); } -// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in -// a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. -Module SequentialNode::operator()(const Module& module) const { - Module mod = module; - for (const Pass& pass : passes) { - CHECK(pass.defined()) << "Found undefined pass for optimization."; - const auto* pn = pass.operator->(); - mod = (*pn)(mod); - } - return mod; +const SequentialNode* Sequential::operator->() const { + return static_cast(this->node_.get()); } void SequentialNode::ResolveDependency(const Module& mod) { @@ -378,18 +427,68 @@ void SequentialNode::ResolveDependency(const Module& mod) { << "\n"; } -std::vector SequentialNode::DisabledPasses() const { - std::vector ret; +std::unordered_set SequentialNode::DisabledPasses( + const Array& disabled) const { + std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); CHECK(str) << "disabled passes must be string."; - ret.push_back(str->value); + ret.emplace(str->value); } return ret; } -void SequentialNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; +std::unordered_set SequentialNode::RequiredPasses( + const Array& required) const { + std::unordered_set ret; + for (const auto& it : required) { + const auto* str = it.as(); + CHECK(str) << "disabled passes must be string."; + ret.emplace(str->value); + } + return ret; +} + +bool SequentialNode::pass_enabled(const std::string& pass_name) const { + PassContext ctx = PassContext::Current(); + + const PassContextNode* ctx_node = ctx.operator->(); + auto required = RequiredPasses(ctx_node->required_pass); + auto disabled = DisabledPasses(ctx_node->required_pass); + + if (disabled.count(pass_name)) { + return false; + } + + if (required.count(pass_name)) { + return true; + } + return ctx_node->opt_level >= opt_pass_level[pass_name]; +} + +// TODO(zhiics): we currenlty only sequentially execute each pass in +// a Sequential without the consideration of their orders. The phase +// ordering problem needed to be handled in the future. +Module SequentialNode::operator()(const Module& module, + const PassContext& pass_ctx) const { + const auto* ctx_node = pass_ctx.operator->(); + int opt_level = ctx_node->opt_level; + auto disabled = DisabledPasses(ctx_node->disabled_pass); + Module mod = module; + for (const Pass& pass : passes) { + CHECK(pass.defined()) << "Found undefined pass for optimization."; + PassInfo info = pass->Info(); + const auto& pass_name = info.operator->()->name; + const auto& pass_opt_level = info.operator->()->opt_level; + // Skip the pass if its optimization level is higher that the one of in the + // pass context or if this pass is disabled. + if (pass_opt_level > opt_level || disabled.count(pass_name)) { + continue; + } + const auto* pn = pass.operator->(); + mod = (*pn)(mod, pass_ctx); + } + return mod; } Pass CreateModulePass( @@ -481,9 +580,8 @@ TVM_REGISTER_API("relay._transform.Sequential") int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - tvm::Array disabled = args[4]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - *ret = Sequential(passes, pass_info, disabled); + *ret = Sequential(passes, pass_info); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -501,26 +599,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -TVM_REGISTER_API("relay._transform.SetContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Pass pass = args[0]; - PassContext pass_ctx = args[1]; - pass->SetContext(pass_ctx); -}); - TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._transform.PassContext") -.set_body_typed(PassContextNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + int opt_level = args[0]; + int fallback_device = args[1]; + tvm::Array required = args[2]; + tvm::Array disabled = args[3]; + *ret = PassContext(opt_level, fallback_device, required, disabled); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PassContextNode* node, - tvm::IRPrinter* p) { - p->stream << "TODO(zhiics): printing context"; - LOG(FATAL) << "PassContext printer has not been implemented yet." - << "\n"; + tvm::IRPrinter* p) { + p->stream << "Pass context information: " << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) + << "\n"; + + p->stream << "\trequired passes: [" << node->opt_level; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: [" << node->opt_level; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]"; }); +class PassContext::Internal { + public: + static void EnterScope(PassContext pass_ctx) { + pass_ctx.EnterWithScope(); + } + + static void ExitScope(PassContext pass_ctx) { + pass_ctx.ExitWithScope(); + } +}; + +TVM_REGISTER_API("relay._transform.GetCurrentPassContext") +.set_body_typed(PassContext::Current); + +TVM_REGISTER_API("relay._transform.EnterPassContext") +.set_body_typed(PassContext::Internal::EnterScope); + +TVM_REGISTER_API("relay._transform.ExitPassContext") +.set_body_typed(PassContext::Internal::ExitScope); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 0fed49079fd2..da78e960091d 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -31,7 +31,7 @@ def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000), input_name='image', dtype='float32'): - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap dtype_dict = {input_name: input_data.dtype} func, params = relay.frontend.from_coreml(coreml_model, shape_dict) - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) from tvm.contrib import graph_runtime diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 35a9229443cb..8817d4faaeaa 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -43,7 +43,7 @@ def get_keras_output(xs, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'): shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} func, params = relay.frontend.from_keras(keras_model, shape_dict) - with relay.build_module.build_config(opt_level=2): + with relay.transform.build_config(opt_level=2): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) for name, x in zip(keras_model.input_names, xs): diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index f8686e9d20ab..01669818bb99 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -144,7 +144,7 @@ def extract(path): # target x86 CPU target = "llvm" -with relay.build_module.build_config(opt_level=3): +with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ###################################################################### From 1e66d2f5a674904f5108ed6e7dbf9291e89b099a Mon Sep 17 00:00:00 2001 From: Siju Date: Sat, 25 May 2019 03:08:08 +0530 Subject: [PATCH 044/176] [RELAY]Frontend darknet (#2773) * [RELAY]Frontend darknet * CI test file updated & CI error fixed * avg_pool pad fix * Changed repo_url and doc formatting --- nnvm/python/nnvm/testing/__init__.py | 1 - .../python/frontend/darknet/test_forward.py | 4 +- nnvm/tutorials/from_darknet.py | 16 +- python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/common.py | 2 +- python/tvm/relay/frontend/darknet.py | 847 ++++++++++++++++++ python/tvm/relay/testing/__init__.py | 1 + .../tvm/relay}/testing/darknet.py | 0 .../tvm/relay}/testing/yolo_detection.py | 0 tests/python/frontend/darknet/test_forward.py | 462 ++++++++++ tests/scripts/task_python_frontend.sh | 9 +- tutorials/frontend/from_darknet.py | 179 ++++ 12 files changed, 1507 insertions(+), 15 deletions(-) create mode 100644 python/tvm/relay/frontend/darknet.py rename {nnvm/python/nnvm => python/tvm/relay}/testing/darknet.py (100%) rename {nnvm/python/nnvm => python/tvm/relay}/testing/yolo_detection.py (100%) create mode 100644 tests/python/frontend/darknet/test_forward.py create mode 100644 tutorials/frontend/from_darknet.py diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 44b8529821d0..41bcf83eb511 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -13,5 +13,4 @@ from . import inception_v3 from . import dcgan from . import dqn -from . import yolo_detection from . import check_computation diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py index 7f45a6149efc..4e62ff2e1f33 100644 --- a/nnvm/tests/python/frontend/darknet/test_forward.py +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -27,8 +27,8 @@ from tvm.contrib.download import download_testdata download_testdata.__test__ = False from nnvm import frontend -from nnvm.testing.darknet import LAYERTYPE -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import LAYERTYPE +from tvm.relay.testing.darknet import __darknetffi__ import nnvm.compiler DARKNET_LIB = 'libdarknet2.0.so' diff --git a/nnvm/tutorials/from_darknet.py b/nnvm/tutorials/from_darknet.py index 857ef46015cd..d2ab647da1b3 100644 --- a/nnvm/tutorials/from_darknet.py +++ b/nnvm/tutorials/from_darknet.py @@ -33,8 +33,8 @@ import nnvm import nnvm.frontend.darknet -import nnvm.testing.yolo_detection -import nnvm.testing.darknet +import tvm.relay.testing.yolo_detection +import tvm.relay.testing.darknet import matplotlib.pyplot as plt import numpy as np import tvm @@ -42,7 +42,7 @@ from ctypes import * from tvm.contrib.download import download_testdata -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import __darknetffi__ # Model name MODEL_NAME = 'yolov3' @@ -104,7 +104,7 @@ test_image + '?raw=true' img_path = download_testdata(img_url, test_image, "data") -data = nnvm.testing.darknet.load_image(img_path, netw, neth) +data = tvm.relay.testing.darknet.load_image(img_path, netw, neth) ###################################################################### # Execute on TVM Runtime # ---------------------- @@ -153,12 +153,12 @@ # do the detection and bring up the bounding boxes thresh = 0.5 nms_thresh = 0.45 -img = nnvm.testing.darknet.load_image_color(img_path) +img = tvm.relay.testing.darknet.load_image_color(img_path) _, im_h, im_w = img.shape -dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, +dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out) last_layer = net.layers[net.n - 1] -nnvm.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) +tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) coco_name = 'coco.names' coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name + '?raw=true' @@ -172,6 +172,6 @@ names = [x.strip() for x in content] -nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) +tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) plt.imshow(img.transpose(1, 2, 0)) plt.show() diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 8d308c7e8833..76761fd78325 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -30,3 +30,4 @@ from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow +from .darknet import from_darknet diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 9b89936de015..23477626b63b 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -241,7 +241,7 @@ def get_relay_op(op_name): op = None else: # try search op in various modules - for candidate in (_op, _op.nn, _op.image): + for candidate in (_op, _op.nn, _op.image, _op.vision): op = getattr(candidate, op_name, None) if op is not None: break diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py new file mode 100644 index 000000000000..6da3525eec21 --- /dev/null +++ b/python/tvm/relay/frontend/darknet.py @@ -0,0 +1,847 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +""" +DarkNet symbol frontend for Relay. +""" + +from __future__ import absolute_import as _abs +from enum import Enum +import numpy as np +import tvm +from .. import ir_pass +from .. import expr as _expr +from .common import get_relay_op, new_var + +__all__ = ['from_darknet'] + +def _darknet_not_support(attr, op='relay'): + """Raise error if any operation is not supported.""" + err = "{} is not supported in {}.".format(attr, op) + raise NotImplementedError(err) + +def _get_params_prefix(opname, layer_num): + """Makes the params prefix name from opname and layer number.""" + return str(opname) + str(layer_num) + +def _get_params_name(prefix, item): + """Makes the params name for the k,v pair.""" + return prefix + '_'+ item + +def _get_param_var(params, prefix, item): + name = _get_params_name(prefix, item) + if name not in params: + raise AttributeError("{} not found in params dict.".format(name)) + return new_var(name, shape=params[name].shape, dtype=params[name].dtype) + +def _darknet_maxpooling(inputs, params, attrs, prefix): + """Process the max pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 1) + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + extra_pad_size = attrs.get('extra_pad_size', 0) + if extra_pad_size: + pad_width = ((0, 0), (0, 0), (0, extra_pad_size), (0, extra_pad_size)) + inputs = [get_relay_op('pad')(*inputs, + pad_width=pad_width, + pad_value=np.finfo(np.float32).min)] + return get_relay_op('max_pool2d')(*inputs, **new_attrs) + +def _darknet_avgpooling(inputs, params, attrs, prefix): + """Process the average pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + return get_relay_op('avg_pool2d')(*inputs, **new_attrs) + +def _darknet_conv2d(inputs, params, attrs, prefix): + """Process the convolution 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['channels'] = attrs.get('num_filter') + new_attrs['kernel_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + new_attrs['dilation'] = attrs.get('dilate', (1, 1)) + new_attrs['groups'] = attrs.get('num_group', 1) + + weight = _get_param_var(params, prefix, 'weight') + out = get_relay_op('conv2d')(*inputs, weight=weight, **new_attrs) + + use_bias = not attrs.get('use_batchNorm', False) + if use_bias: + new_attrs = {} + new_attrs['axis'] = 1 + bias = _get_param_var(params, prefix, 'bias') + out = get_relay_op('bias_add')(out, bias=bias, **new_attrs) + else: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + out = get_relay_op('batch_norm')(out, gamma, beta, moving_mean, moving_var, **new_attrs) + + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + new_attrs['slope'] = 0.1 + out = _darknet_activations(out, None, new_attrs) + return out + +def _darknet_shortcut(inputs, params, attrs, prefix): + """Process the shortcut operation.""" + input_0 = inputs[0] + input_1 = inputs[1] + + input_0_channel = int(attrs['out_channel']) + input_1_channel = int(attrs['add_out_channel']) + input_0_size = int(attrs['out_size']) + input_1_size = int(attrs['add_out_size']) + + if input_0_size > input_1_size: + scale = int(input_0_size/input_1_size) + input_1 = get_relay_op('upsampling')(input_1, scale=scale) + + elif input_0_size < input_1_size: + stride = int(input_1_size/input_0_size) + input_1 = get_relay_op('avg_pool2d')(input_1, + pool_size=(1, 1), + strides=(stride, stride), + padding=(0, 0)) + + if input_0_channel != input_1_channel: + pad_channel = input_0_channel - input_1_channel + input_1 = get_relay_op('pad')(input_1, + pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)), + pad_value=0.) + sym = input_0 + input_1 + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + sym = _darknet_activations(sym, None, new_attrs) + return sym + +def _darknet_dense(inputs, params, attrs, prefix): + """Process the dense operation.""" + new_attrs = {} + new_attrs['units'] = attrs.get('num_hidden') + data = inputs[0] + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + weight = _get_param_var(params, prefix, 'weight') + data = get_relay_op('dense')(data, weight, **new_attrs) + + use_bias = attrs.get('use_bias', False) + if use_bias: + bias = _get_param_var(params, prefix, 'bias') + data = get_relay_op('bias_add')(data, bias, axis=1) + + if 'use_batchNorm' in attrs: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + data = get_relay_op('batch_norm')(data, gamma, beta, moving_mean, moving_var, **new_attrs) + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + data = _darknet_activations(data, None, new_attrs) + return data + +def _darknet_dropout(inputs, params, attrs, prefix): + """Process the dropout operation, its a blank operation.""" + new_attrs = {} + new_attrs['rate'] = attrs.get('p', 0.5) + return get_relay_op('dropout')(*inputs, **new_attrs) + +def _darknet_reshape(inputs, params, attrs, prefix): + """Process the reshape operation.""" + new_attrs = {} + new_attrs['shape'] = attrs.get('shape') + return get_relay_op('reshape')(*inputs, **new_attrs) + +def _darknet_upsampling(inputs, params, attrs, prefix): + """Process the upsampling operation.""" + new_attrs = {} + new_attrs['scale'] = attrs.get('scale', 1) + return get_relay_op('upsampling')(*inputs, **new_attrs) + +def _darknet_l2normalize(inputs, params, attrs, prefix): + """Process the l2 normalization operation.""" + new_attrs = {} + new_attrs['eps'] = attrs.get('eps', 0.0) + new_attrs['axis'] = [attrs.get('axis', 1)] + return get_relay_op('l2_normalize')(*inputs, **new_attrs) + +def _darknet_softmax_output(inputs, params, attrs, prefix): + """Process the softmax operation.""" + temperature = attrs.get('temperature', 1) + data = inputs[0] + if temperature != 1: + data = data / _expr.const(float(temperature)) + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + new_attrs = {} + if attrs.get('multi_output', False): + new_attrs['axis'] = 1 + return get_relay_op('softmax')(data, **new_attrs) + +def _darknet_route(inputs, params, attrs, prefix): + """Process the route operation, which is equivalent to concat.""" + new_attrs = {'axis': attrs.get('dim', 1)} + return get_relay_op('concatenate')((inputs[0], inputs[1]), **new_attrs) + +def _darknet_reorg(inputs, params, attrs, prefix): + """Process the reorg operation.""" + new_attrs = {} + if 'stride' in attrs: + new_attrs = {'stride': attrs.get('stride', 1)} + return get_relay_op('yolo_reorg')(*inputs, **new_attrs) + +def _darknet_region(inputs, params, attrs, prefix): + """Process the region operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + coords = attrs.get('coords', 0) + background = attrs.get('background', 0) + softmax = attrs.get('softmax', True) + input_shape = attrs.get('shape') + + split_size = classes + coords + 1 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4, 5) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = split_res[2] if background else get_relay_op('sigmoid')(split_res[2]) + split_res3 = get_relay_op('softmax')(split_res[3], axis=2) if softmax else split_res[3] + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2, split_res3), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +def _darknet_yolo(inputs, params, attrs, prefix): + """Process the yolo operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + input_shape = attrs.get('shape') + split_size = classes + 5 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = get_relay_op('sigmoid')(split_res[2]) + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +class ACTIVATION(object): + """Darknet ACTIVATION Class constant.""" + LOGISTIC = 0 + RELU = 1 + RELIE = 2 + LINEAR = 3 + RAMP = 4 + TANH = 5 + PLSE = 6 + LEAKY = 7 + ELU = 8 + LOGGY = 9 + STAIR = 10 + HARDTAN = 11 + LHTAN = 12 + +def _darknet_activations(inputs, params, attrs): + """Process the activation function.""" + act = attrs.get('activation') + data = inputs[0] if isinstance(inputs, _expr.TupleWrapper) else inputs + + def _const(val): + return _expr.const(val) + + def _relu(data): + return get_relay_op('relu')(data) + + def _exp(data): + return get_relay_op('exp')(data) + + def _tanh(data): + return get_relay_op('tanh')(data) + + def _sigmoid(data): + return get_relay_op('sigmoid')(data) + + def _elu(data): + alpha = _const(-1.0) + return alpha * _relu(_const(1.0) - _exp(data)) + _relu(data) + + def _leaky_relu(data, slope): + new_attrs = {} + new_attrs['alpha'] = slope + return get_relay_op('leaky_relu')(data, **new_attrs) + + if ACTIVATION.LOGISTIC == act: + data = _sigmoid(data) + elif ACTIVATION.RELU == act: + data = _relu(data) + elif ACTIVATION.TANH == act: + data = _tanh(data) + elif ACTIVATION.LINEAR == act: + return data + elif ACTIVATION.LEAKY == act: + data = _leaky_relu(data, attrs.get('slope', 0.1)) + elif ACTIVATION.ELU == act: + data = _elu(data) + else: + _darknet_not_support('act: ' + attrs) + return data + +class LAYERTYPE(Enum): + """Darknet LAYERTYPE Class constant.""" + CONVOLUTIONAL = 0 + DECONVOLUTIONAL = 1 + CONNECTED = 2 + MAXPOOL = 3 + SOFTMAX = 4 + DETECTION = 5 + DROPOUT = 6 + CROP = 7 + ROUTE = 8 + COST = 9 + NORMALIZATION = 10 + AVGPOOL = 11 + LOCAL = 12 + SHORTCUT = 13 + ACTIVE = 14 + RNN = 15 + GRU = 16 + LSTM = 17 + CRNN = 18 + BATCHNORM = 19 + NETWORK = 20 + XNOR = 21 + REGION = 22 + YOLO = 23 + REORG = 24 + UPSAMPLE = 25 + LOGXENT = 26 + L2NORM = 27 + BLANK = 28 + +_DARKNET_CONVERT_MAP = { + LAYERTYPE.CONVOLUTIONAL : _darknet_conv2d, + LAYERTYPE.CONNECTED : _darknet_dense, + LAYERTYPE.MAXPOOL : _darknet_maxpooling, + LAYERTYPE.SOFTMAX : _darknet_softmax_output, + LAYERTYPE.DROPOUT : _darknet_dropout, + LAYERTYPE.AVGPOOL : _darknet_avgpooling, + LAYERTYPE.ROUTE : _darknet_route, + LAYERTYPE.REORG : _darknet_reorg, + LAYERTYPE.REGION : _darknet_region, + LAYERTYPE.SHORTCUT : _darknet_shortcut, + LAYERTYPE.UPSAMPLE : _darknet_upsampling, + LAYERTYPE.L2NORM : _darknet_l2normalize, + LAYERTYPE.YOLO : _darknet_yolo, + LAYERTYPE.DECONVOLUTIONAL : _darknet_not_support, + LAYERTYPE.BATCHNORM : _darknet_not_support, + LAYERTYPE.DETECTION : _darknet_not_support, + LAYERTYPE.CROP : _darknet_not_support, + LAYERTYPE.COST : _darknet_not_support, + LAYERTYPE.NORMALIZATION : _darknet_not_support, + LAYERTYPE.LOCAL : _darknet_not_support, + LAYERTYPE.ACTIVE : _darknet_not_support, + LAYERTYPE.RNN : _darknet_not_support, + LAYERTYPE.GRU : _darknet_not_support, + LAYERTYPE.LSTM : _darknet_not_support, + LAYERTYPE.CRNN : _darknet_not_support, + LAYERTYPE.NETWORK : _darknet_not_support, + LAYERTYPE.XNOR : _darknet_not_support, + LAYERTYPE.BLANK : _darknet_not_support, +} + +def _darknet_convert_symbol(op_name, inputs, params, attrs, params_prefix): + """Convert from darknet op to relay op. + Parameters + ---------- + op_name : str + Operator name, such as Convolution, Connected, etc + inputs : list of relay.Function + List of input symbols. + attrs : dict + Dict of operator attributes + params_prefix: str + Params name for this operation + + Returns + ------- + out_name : converted out name of operation + sym : tvm.relay.Function + Converted relay function + """ + + if op_name in _DARKNET_CONVERT_MAP: + sym = _DARKNET_CONVERT_MAP[op_name](inputs, params, attrs, params_prefix) + else: + _darknet_not_support('Operator type ' + str(op_name)) + return sym + +def _as_list(arr): + """Force being a list, ignore if already is.""" + if isinstance(arr, list): + return arr + return [arr] + +class GraphProto(object): + """A helper class for handling relay functions from darknet model. + """ + + def __init__(self, net, shape, dtype='float32'): + self._net = net + self._shape = shape + self._dtype = dtype + self._sym_array = {} + self._tvmparams = {} + self._outs = [] + self._state_ctr = {} + self._state_ctr['rnn'] = 0 + self._state_ctr['crnn'] = 0 + self._state_ctr['lstm'] = 0 + self._state_ctr['cell_state'] = 0 + self._state_ctr['gru'] = 0 + + def _read_memory_buffer(self, shape, data, dtype=None): + if dtype is None: + dtype = self._dtype + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + + def _get_convolution_weights(self, layer, opname): + """Get the convolution layer weights and biases.""" + if layer.nweights == 0: + return None + + if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: + raise RuntimeError("layer weights size not matching with n c h w") + + params = {} + shape = (layer.n, layer.c, layer.size, layer.size) + weights = self._read_memory_buffer(shape, layer.weights) + + biases = self._read_memory_buffer((layer.n, ), layer.biases) + + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.n)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_connected_weights(self, layer, opname): + """Parse the weights and biases for fully connected or dense layer.""" + size = layer.outputs * layer.inputs + if size == 0: + return None + + weights = self._read_memory_buffer((layer.outputs, layer.inputs), layer.weights) + biases = self._read_memory_buffer((layer.outputs, ), layer.biases) + + params = {} + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.outputs)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_region_weights(self, layer, opname): + """Parse the biases for region layer.""" + biases = self._read_memory_buffer((layer.n*2, ), layer.biases) + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.coords, layer.background], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_yolo_weights(self, layer, opname): + """Parse the biases and mask for yolo layer.""" + biases = self._read_memory_buffer((layer.total*2, ), layer.biases) + mask = self._read_memory_buffer((layer.n, ), layer.mask, dtype='int32') + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.total], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'mask') + params[k] = tvm.nd.array(mask) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_batchnorm_weights(self, layer, opname, size): + """Parse the weights for batchnorm, which includes, scales, moving mean + and moving variances.""" + scales = self._read_memory_buffer((size, ), layer.scales) + rolling_mean = self._read_memory_buffer((size, ), layer.rolling_mean) + rolling_variance = self._read_memory_buffer((size, ), layer.rolling_variance) + + params = {} + k = _get_params_name(opname, 'moving_mean') + params[k] = tvm.nd.array(rolling_mean) + k = _get_params_name(opname, 'moving_var') + params[k] = tvm.nd.array(rolling_variance) + k = _get_params_name(opname, 'gamma') + params[k] = tvm.nd.array(scales) + return params + + def _get_darknet_attrs(self, layer, layer_num): + """Parse attributes of each layer and return.""" + attr = {} + use_flatten = True + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.CONVOLUTIONAL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'num_group' : layer.groups}) + attr.update({'num_filter' : layer.n}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + attr.update({'activation' : (layer.activation)}) + + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + elif LAYERTYPE.CONNECTED == layer_type: + attr.update({'num_hidden' : layer.outputs}) + attr.update({'activation' : (layer.activation)}) + if layer_num != 0: + layer_prev = self._net.layers[layer_num - 1] + if (layer_prev.out_h == layer.h and + layer_prev.out_w == layer.w and + layer_prev.out_c == layer.c): + use_flatten = False + attr.update({'use_flatten' : use_flatten}) + attr.update({'use_bias' : True}) + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + attr.update({'use_bias' : False}) + + elif LAYERTYPE.MAXPOOL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1 + if max_output < layer.out_w: + extra_pad = (layer.out_w - max_output)*layer.stride + attr.update({'extra_pad_size' : int(extra_pad)}) + elif LAYERTYPE.AVGPOOL == layer_type: + attr.update({'pad' : layer.pad}) + if layer.stride == 0: + attr.update({'stride' : 1}) + else: + attr.update({'stride' : layer.stride}) + if layer.size == 0 and layer.h == layer.w: + attr.update({'kernel' : layer.h}) + else: + attr.update({'kernel' : layer.size}) + + elif LAYERTYPE.DROPOUT == layer_type: + attr.update({'p' : layer.probability}) + + elif LAYERTYPE.SOFTMAX == layer_type: + attr.update({'axis' : 1}) + attr.update({'use_flatten' : True}) + if layer.temperature: + attr.update({'temperature' : str(layer.temperature)}) + + elif LAYERTYPE.SHORTCUT == layer_type: + add_layer = self._net.layers[layer.index] + attr.update({'activation' : layer.activation}) + attr.update({'out_channel' : layer.out_c}) + attr.update({'out_size' : layer.out_h}) + attr.update({'add_out_channel' : add_layer.out_c}) + attr.update({'add_out_size' : add_layer.out_h}) + + elif LAYERTYPE.ROUTE == layer_type: + pass + + elif LAYERTYPE.COST == layer_type: + pass + + elif LAYERTYPE.REORG == layer_type: + attr.update({'stride' : layer.stride}) + + elif LAYERTYPE.REGION == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'coords' : layer.coords}) + attr.update({'background' : layer.background}) + attr.update({'softmax' : layer.softmax}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.YOLO == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.UPSAMPLE == layer_type: + attr.update({'scale' : layer.stride}) + + elif LAYERTYPE.L2NORM == layer_type: + pass + + else: + err = "Darknet layer type {} is not supported in relay.".format(layer_type) + raise NotImplementedError(err) + + return attr + + def _get_darknet_params(self, layer, opname): + """To parse and get the darknet params.""" + layer_type = LAYERTYPE(layer.type) + params = None + if LAYERTYPE.CONVOLUTIONAL == layer_type: + params = self._get_convolution_weights(layer, opname) + elif LAYERTYPE.CONNECTED == layer_type: + params = self._get_connected_weights(layer, opname) + elif LAYERTYPE.REGION == layer_type: + params = self._get_region_weights(layer, opname) + elif LAYERTYPE.YOLO == layer_type: + params = self._get_yolo_weights(layer, opname) + return params + + def _preproc_layer(self, layer, layer_num): + """To preprocess each darknet layer, some layer doesnt need processing.""" + if layer_num == 0: + name = 'data' + sym = new_var(name, shape=self._shape, dtype=self._dtype) + else: + sym = self._sym_array[layer_num - 1] + skip_layer = False + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.ROUTE == layer_type: + sym = [] + for j in range(layer.n): + sym.append(self._sym_array[layer.input_layers[j]]) + if layer.n == 1: + skip_layer = True + + elif LAYERTYPE.COST == layer_type: + skip_layer = True + + elif LAYERTYPE.SHORTCUT == layer_type: + sym = [sym, self._sym_array[layer.index]] + + elif LAYERTYPE.BLANK == layer_type: + skip_layer = True + + if skip_layer is True: + self._sym_array[layer_num] = sym + + return skip_layer, sym + + def _get_opname(self, layer): + """Returs the layer name.""" + return LAYERTYPE(layer.type) + + def _new_rnn_state_var(self, state=None, name='rnn'): + """Returs a symbol for state""" + sym_name = name + "%d_state" % self._state_ctr[name] + self._state_ctr[name] += 1 + return new_var(sym_name, shape=state.shape, dtype=str(state.dtype)) + + def _get_rnn_state_buffer(self, layer, name): + """Get the state buffer for rnn.""" + buffer = np.zeros((1, layer.outputs), self._dtype) + return self._new_rnn_state_var(buffer, name) + + def _get_darknet_rnn_attrs(self, layer, name, sym): + """Get the rnn converted symbol from attributes.""" + attr = self._get_darknet_attrs(layer, 0) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, name) + params = self._get_darknet_params(layer, prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + if params: + self._tvmparams.update(params) + return sym + + def _handle_darknet_rnn_layers(self, layer_num, sym): + """Parse attributes and handle the rnn layers.""" + attr = {} + layer = self._net.layers[layer_num] + processed = False + + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.RNN == layer_type: + attr.update({'n' : layer.n}) + attr.update({'batch' : layer.batch}) + attr.update({'num_hidden' : str(layer.outputs)}) + state = self._get_rnn_state_buffer(layer, 'rnn') + for _ in range(layer.steps): + input_layer = layer.input_layer + prefix = "_input_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(input_layer, prefix, sym) + + self_layer = layer.self_layer + prefix = "_self_" + str(layer_num) + state = self._get_darknet_rnn_attrs(self_layer, prefix, state) + + state = sym + state + self._outs.append(state) + + output_layer = layer.output_layer + prefix = "_output_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(output_layer, prefix, state) + + self._sym_array[layer_num] = sym + processed = True + return processed, sym + + def _make_outlist(self, sym, op_name, layer, layer_num): + layer_type = LAYERTYPE(layer.type) + if layer_type == LAYERTYPE.REGION: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + if layer_num != self._net.n-1: + self._outs.insert(0, sym) + + elif layer_type == LAYERTYPE.YOLO: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add mask + k = _get_params_name(op_name, 'mask') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + if layer_num != self._net.n-1: + self._outs.insert(0, sym) + + def from_darknet(self): + """To convert the darknet symbol to relay functions.""" + for i in range(self._net.n): + layer = self._net.layers[i] + need_skip, sym = self._preproc_layer(layer, i) + if need_skip: + continue + + processed, sym = self._handle_darknet_rnn_layers(i, sym) + if processed: + continue + + attr = self._get_darknet_attrs(layer, i) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, i) + params = self._get_darknet_params(self._net.layers[i], prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + + if params: + self._tvmparams.update(params) + self._sym_array[i] = sym + self._make_outlist(sym, prefix, layer, i) + + outputs = _as_list(sym) + self._outs + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + sym = _expr.Function(ir_pass.free_vars(outputs), outputs) + return sym, self._tvmparams + +def from_darknet(net, + shape=None, + dtype="float32"): + """Convert from Darknet's model into compatible relay Function. + + Parameters + ---------- + net : Darknet net parameter + Darknet net structure. + shape : dict of str to tuple, optional + The input shape to the graph + dtype : str or dict of str to str + The input types to the graph + + Returns + ------- + sym : tvm.relay.Function + Compatible relay Function + params : dict of str to tvm.NDArray + The parameter dict to be used by relay + """ + + return GraphProto(net, shape, dtype).from_darknet() diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 192afe1ef914..7a5007bbfb8f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -27,6 +27,7 @@ from . import squeezenet from . import vgg from . import densenet +from . import yolo_detection from .config import ctx_list from .init import create_workload diff --git a/nnvm/python/nnvm/testing/darknet.py b/python/tvm/relay/testing/darknet.py similarity index 100% rename from nnvm/python/nnvm/testing/darknet.py rename to python/tvm/relay/testing/darknet.py diff --git a/nnvm/python/nnvm/testing/yolo_detection.py b/python/tvm/relay/testing/yolo_detection.py similarity index 100% rename from nnvm/python/nnvm/testing/yolo_detection.py rename to python/tvm/relay/testing/yolo_detection.py diff --git a/tests/python/frontend/darknet/test_forward.py b/tests/python/frontend/darknet/test_forward.py new file mode 100644 index 000000000000..3545e8a902bd --- /dev/null +++ b/tests/python/frontend/darknet/test_forward.py @@ -0,0 +1,462 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test Darknet Models +=================== +This article is a test script to test darknet models with Relay. +All the required models and libraries will be downloaded from the internet +by the script. +""" +import numpy as np +import tvm +from tvm.contrib import graph_runtime +from tvm.contrib.download import download_testdata +download_testdata.__test__ = False +from tvm.relay.testing.darknet import LAYERTYPE +from tvm.relay.testing.darknet import __darknetffi__ +from tvm.relay.frontend.darknet import ACTIVATION +from tvm import relay + +REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/' +DARKNET_LIB = 'libdarknet2.0.so' +DARKNETLIB_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' +LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet')) + +DARKNET_TEST_IMAGE_NAME = 'dog.jpg' +DARKNET_TEST_IMAGE_URL = REPO_URL + 'data/' + DARKNET_TEST_IMAGE_NAME +'?raw=true' +DARKNET_TEST_IMAGE_PATH = download_testdata(DARKNET_TEST_IMAGE_URL, DARKNET_TEST_IMAGE_NAME, module='data') + +def _read_memory_buffer(shape, data, dtype='float32'): + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + +def _get_tvm_output(net, data, build_dtype='float32', states=None): + '''Compute TVM output''' + dtype = 'float32' + sym, params = relay.frontend.from_darknet(net, data.shape, dtype) + target = 'llvm' + shape_dict = {'data': data.shape} + graph, library, params = relay.build(sym, target, params=params) + + # Execute on TVM + ctx = tvm.cpu(0) + m = graph_runtime.create(graph, library, ctx) + # set inputs + m.set_input('data', tvm.nd.array(data.astype(dtype))) + if states: + for name in states.keys(): + m.set_input(name, tvm.nd.array(states[name].astype(dtype))) + m.set_input(**params) + m.run() + # get outputs + tvm_out = [] + for i in range(m.get_num_outputs()): + tvm_out.append(m.get_output(i).asnumpy()) + return tvm_out + +def _load_net(cfg_url, cfg_name, weights_url, weights_name): + cfg_path = download_testdata(cfg_url, cfg_name, module='darknet') + weights_path = download_testdata(weights_url, weights_name, module='darknet') + net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0) + return net + +def verify_darknet_frontend(net, build_dtype='float32'): + '''Test network with given input image on both darknet and tvm''' + def get_darknet_output(net, img): + LIB.network_predict_image(net, img) + out = [] + for i in range(net.n): + layer = net.layers[i] + if layer.type == LAYERTYPE.REGION: + attributes = np.array([layer.n, layer.out_c, layer.out_h, + layer.out_w, layer.classes, + layer.coords, layer.background], + dtype=np.int32) + out.insert(0, attributes) + out.insert(0, _read_memory_buffer((layer.n*2, ), layer.biases)) + layer_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(layer_outshape, layer.output)) + elif layer.type == LAYERTYPE.YOLO: + attributes = np.array([layer.n, layer.out_c, layer.out_h, + layer.out_w, layer.classes, + layer.total], + dtype=np.int32) + out.insert(0, attributes) + out.insert(0, _read_memory_buffer((layer.total*2, ), layer.biases)) + out.insert(0, _read_memory_buffer((layer.n, ), layer.mask, dtype='int32')) + layer_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(layer_outshape, layer.output)) + elif i == net.n-1: + if layer.type == LAYERTYPE.CONNECTED: + darknet_outshape = (layer.batch, layer.out_c) + elif layer.type in [LAYERTYPE.SOFTMAX]: + darknet_outshape = (layer.batch, layer.outputs) + else: + darknet_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(darknet_outshape, layer.output)) + return out + + dtype = 'float32' + + img = LIB.letterbox_image(LIB.load_image_color(DARKNET_TEST_IMAGE_PATH.encode('utf-8'), 0, 0), net.w, net.h) + darknet_output = get_darknet_output(net, img) + batch_size = 1 + data = np.empty([batch_size, img.c, img.h, img.w], dtype) + i = 0 + for c in range(img.c): + for h in range(img.h): + for k in range(img.w): + data[0][c][h][k] = img.data[i] + i = i + 1 + + tvm_out = _get_tvm_output(net, data, build_dtype) + for tvm_outs, darknet_out in zip(tvm_out, darknet_output): + tvm.testing.assert_allclose(darknet_out, tvm_outs, rtol=1e-3, atol=1e-3) + +def _test_rnn_network(net, states): + '''Test network with given input data on both darknet and tvm''' + def get_darknet_network_predict(net, data): + return LIB.network_predict(net, data) + from cffi import FFI + ffi = FFI() + np_arr = np.zeros([1, net.inputs], dtype='float32') + np_arr[0, 2] = 1 + cffi_arr = ffi.cast('float*', np_arr.ctypes.data) + tvm_out = _get_tvm_output(net, np_arr, states=states)[0] + darknet_output = get_darknet_network_predict(net, cffi_arr) + darknet_out = np.zeros(net.outputs, dtype='float32') + for i in range(net.outputs): + darknet_out[i] = darknet_output[i] + last_layer = net.layers[net.n-1] + darknet_outshape = (last_layer.batch, last_layer.outputs) + darknet_out = darknet_out.reshape(darknet_outshape) + tvm.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-4, atol=1e-4) + +def test_forward_extraction(): + '''test extraction model''' + model_name = 'extraction' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_alexnet(): + '''test alexnet model''' + model_name = 'alexnet' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_resnet50(): + '''test resnet50 model''' + model_name = 'resnet50' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_yolov2(): + '''test yolov2 model''' + model_name = 'yolov2' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_yolov3(): + '''test yolov3 model''' + model_name = 'yolov3' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_convolutional(): + '''test convolutional layer''' + net = LIB.make_network(1) + layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_dense(): + '''test fully connected layer''' + net = LIB.make_network(1) + layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0) + net.layers[0] = layer + net.w = net.h = 5 + LIB.resize_network(net, 5, 5) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_dense_batchnorm(): + '''test fully connected layer with batchnorm''' + net = LIB.make_network(1) + layer = LIB.make_connected_layer(1, 12, 2, 1, 1, 0) + for i in range(5): + layer.rolling_mean[i] = np.random.rand(1) + layer.rolling_variance[i] = np.random.rand(1) + layer.scales[i] = np.random.rand(1) + net.layers[0] = layer + net.w = net.h = 2 + LIB.resize_network(net, 2, 2) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_maxpooling(): + '''test maxpooling layer''' + net = LIB.make_network(1) + layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_avgpooling(): + '''test avgerage pooling layer''' + net = LIB.make_network(1) + layer = LIB.make_avgpool_layer(1, 224, 224, 3) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_conv_batch_norm(): + '''test batch normalization layer''' + net = LIB.make_network(1) + layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0) + for i in range(32): + layer.rolling_mean[i] = np.random.rand(1) + layer.rolling_variance[i] = np.random.rand(1) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_shortcut(): + '''test shortcut layer''' + net = LIB.make_network(3) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0) + layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32) + layer_3.activation = ACTIVATION.RELU + layer_3.alpha = 1 + layer_3.beta = 1 + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.layers[2] = layer_3 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_reorg(): + '''test reorg layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0) + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 222 + LIB.resize_network(net, 222, 222) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_region(): + '''test region layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 19, 19, 3, 425, 1, 1, 1, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_region_layer(1, 19, 19, 5, 80, 4) + layer_2.softmax = 1 + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 19 + LIB.resize_network(net, 19, 19) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_yolo_op(): + '''test yolo layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 14, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_yolo_layer(1, 111, 111, 2, 9, __darknetffi__.NULL, 2) + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_upsample(): + '''test upsample layer''' + net = LIB.make_network(1) + layer = LIB.make_upsample_layer(1, 19, 19, 3, 3) + layer.scale = 1 + net.layers[0] = layer + net.w = net.h = 19 + LIB.resize_network(net, 19, 19) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_l2normalize(): + '''test l2 normalization layer''' + net = LIB.make_network(1) + layer = LIB.make_l2norm_layer(1, 224*224*3) + layer.c = layer.out_c = 3 + layer.h = layer.out_h = 224 + layer.w = layer.out_w = 224 + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_elu(): + '''test elu activation layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_1.activation = ACTIVATION.ELU + net.layers[0] = layer_1 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_softmax(): + '''test softmax layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_softmax_layer(1, 75, 1) + layer_1.temperature = 1 + net.layers[0] = layer_1 + net.w = net.h = 5 + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_softmax_temperature(): + '''test softmax layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_softmax_layer(1, 75, 1) + layer_1.temperature = 0.8 + net.layers[0] = layer_1 + net.w = net.h = 5 + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_activation_logistic(): + '''test logistic activation layer''' + net = LIB.make_network(1) + batch = 1 + h = 224 + w = 224 + c = 3 + n = 32 + groups = 1 + size = 3 + stride = 2 + padding = 0 + activation = ACTIVATION.LOGISTIC + batch_normalize = 0 + binary = 0 + xnor = 0 + adam = 0 + layer_1 = LIB.make_convolutional_layer(batch, h, w, c, n, groups, size, stride, padding, + activation, batch_normalize, binary, xnor, adam) + net.layers[0] = layer_1 + net.w = w + net.h = h + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_rnn(): + '''test RNN layer''' + net = LIB.make_network(1) + batch = 1 + inputs = 4 + outputs = 4 + steps = 1 + activation = ACTIVATION.RELU + batch_normalize = 0 + adam = 0 + layer_1 = LIB.make_rnn_layer(batch, inputs, outputs, steps, activation, batch_normalize, adam) + net.layers[0] = layer_1 + net.inputs = inputs + net.outputs = outputs + net.w = net.h = 0 + LIB.resize_network(net, net.w, net.h) + states = {"rnn0_state": np.zeros([1, net.inputs])} + _test_rnn_network(net, states) + LIB.free_network(net) + +if __name__ == '__main__': + test_forward_resnet50() + test_forward_alexnet() + test_forward_extraction() + test_forward_yolov2() + test_forward_yolov3() + test_forward_convolutional() + test_forward_maxpooling() + test_forward_avgpooling() + test_forward_conv_batch_norm() + test_forward_shortcut() + test_forward_dense() + test_forward_dense_batchnorm() + test_forward_softmax() + test_forward_softmax_temperature() + test_forward_reorg() + test_forward_region() + test_forward_yolo_op() + test_forward_upsample() + test_forward_l2normalize() + test_forward_elu() + test_forward_rnn() + test_forward_activation_logistic() diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 37159dbc9a58..609b00149bad 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -62,10 +62,10 @@ python3 -m nose -v tests/python/frontend/mxnet echo "Running relay Keras frontend test..." python3 -m nose -v tests/python/frontend/keras -echo "Running relay ONNX frondend test..." +echo "Running relay ONNX frontend test..." python3 -m nose -v tests/python/frontend/onnx -echo "Running relay CoreML frondend test..." +echo "Running relay CoreML frontend test..." python3 -m nose -v tests/python/frontend/coreml echo "Running nnvm to relay frontend test..." @@ -74,5 +74,8 @@ python3 -m nose -v tests/python/frontend/nnvm_to_relay echo "Running relay Tensorflow frontend test..." python3 -m nose -v tests/python/frontend/tensorflow -echo "Running relay caffe2 frondend test..." +echo "Running relay caffe2 frontend test..." python3 -m nose -v tests/python/frontend/caffe2 + +echo "Running relay DarkNet frontend test..." +python3 -m nose -v tests/python/frontend/darknet || exit -1 diff --git a/tutorials/frontend/from_darknet.py b/tutorials/frontend/from_darknet.py new file mode 100644 index 000000000000..2658a353e34e --- /dev/null +++ b/tutorials/frontend/from_darknet.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile YOLO-V2 and YOLO-V3 in DarkNet Models +============================================= +**Author**: `Siju Samuel `_ + +This article is an introductory tutorial to deploy darknet models with TVM. +All the required models and libraries will be downloaded from the internet by the script. +This script runs the YOLO-V2 and YOLO-V3 Model with the bounding boxes +Darknet parsing have dependancy with CFFI and CV2 library +Please install CFFI and CV2 before executing this script + +.. code-block:: bash + + pip install cffi + pip install opencv-python +""" + +# numpy and matplotlib +import numpy as np +import matplotlib.pyplot as plt +import sys + +# tvm, relay +import tvm +from tvm import relay +from ctypes import * +from tvm.contrib.download import download_testdata +from tvm.relay.testing.darknet import __darknetffi__ +import tvm.relay.testing.yolo_detection +import tvm.relay.testing.darknet + +# Model name +MODEL_NAME = 'yolov3' + +###################################################################### +# Download required files +# ----------------------- +# Download cfg and weights file if first time. +CFG_NAME = MODEL_NAME + '.cfg' +WEIGHTS_NAME = MODEL_NAME + '.weights' +REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/' +CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true' +WEIGHTS_URL = 'https://pjreddie.com/media/files/' + WEIGHTS_NAME + +cfg_path = download_testdata(CFG_URL, CFG_NAME, module="darknet") +weights_path = download_testdata(WEIGHTS_URL, WEIGHTS_NAME, module="darknet") + +# Download and Load darknet library +if sys.platform in ['linux', 'linux2']: + DARKNET_LIB = 'libdarknet2.0.so' + DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' +elif sys.platform == 'darwin': + DARKNET_LIB = 'libdarknet_mac2.0.so' + DARKNET_URL = REPO_URL + 'lib_osx/' + DARKNET_LIB + '?raw=true' +else: + err = "Darknet lib is not supported on {} platform".format(sys.platform) + raise NotImplementedError(err) + +lib_path = download_testdata(DARKNET_URL, DARKNET_LIB, module="darknet") + +DARKNET_LIB = __darknetffi__.dlopen(lib_path) +net = DARKNET_LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0) +dtype = 'float32' +batch_size = 1 + +data = np.empty([batch_size, net.c, net.h, net.w], dtype) +shape_dict = {'data': data.shape} +print("Converting darknet to relay functions...") +sym, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# compile the model +target = 'llvm' +target_host = 'llvm' +ctx = tvm.cpu(0) +data = np.empty([batch_size, net.c, net.h, net.w], dtype) +shape = {'data': data.shape} +print("Compiling the model...") +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) + +[neth, netw] = shape['data'][2:] # Current image shape is 608x608 +###################################################################### +# Load a test image +# ----------------- +test_image = 'dog.jpg' +print("Loading the test image...") +img_url = REPO_URL + 'data/' + test_image + '?raw=true' +img_path = download_testdata(img_url, test_image, "data") + +data = tvm.relay.testing.darknet.load_image(img_path, netw, neth) +###################################################################### +# Execute on TVM Runtime +# ---------------------- +# The process is no different from other examples. +from tvm.contrib import graph_runtime + +m = graph_runtime.create(graph, lib, ctx) + +# set inputs +m.set_input('data', tvm.nd.array(data.astype(dtype))) +m.set_input(**params) +# execute +print("Running the test image...") + +m.run() +# get outputs +tvm_out = [] +if MODEL_NAME == 'yolov2': + layer_out = {} + layer_out['type'] = 'Region' + # Get the region layer attributes (n, out_c, out_h, out_w, classes, coords, background) + layer_attr = m.get_output(2).asnumpy() + layer_out['biases'] = m.get_output(1).asnumpy() + out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0], + layer_attr[2], layer_attr[3]) + layer_out['output'] = m.get_output(0).asnumpy().reshape(out_shape) + layer_out['classes'] = layer_attr[4] + layer_out['coords'] = layer_attr[5] + layer_out['background'] = layer_attr[6] + tvm_out.append(layer_out) + +elif MODEL_NAME == 'yolov3': + for i in range(3): + layer_out = {} + layer_out['type'] = 'Yolo' + # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total) + layer_attr = m.get_output(i*4+3).asnumpy() + layer_out['biases'] = m.get_output(i*4+2).asnumpy() + layer_out['mask'] = m.get_output(i*4+1).asnumpy() + out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0], + layer_attr[2], layer_attr[3]) + layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape) + layer_out['classes'] = layer_attr[4] + tvm_out.append(layer_out) + +# do the detection and bring up the bounding boxes +thresh = 0.5 +nms_thresh = 0.45 +img = tvm.relay.testing.darknet.load_image_color(img_path) +_, im_h, im_w = img.shape +dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, + 1, tvm_out) +last_layer = net.layers[net.n - 1] +tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) + +coco_name = 'coco.names' +coco_url = REPO_URL + 'data/' + coco_name + '?raw=true' +font_name = 'arial.ttf' +font_url = REPO_URL + 'data/' + font_name + '?raw=true' +coco_path = download_testdata(coco_url, coco_name, module='data') +font_path = download_testdata(font_url, font_name, module='data') + +with open(coco_path) as f: + content = f.readlines() + +names = [x.strip() for x in content] + +tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) +plt.imshow(img.transpose(1, 2, 0)) +plt.show() From 51a2d641b36053c5320aa1525ac0616b951bba6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 24 May 2019 16:42:29 -0700 Subject: [PATCH 045/176] [Relay] remove unneeded VisitExpr (#3239) --- src/relay/pass/gradient.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 5c5ea01ac2f3..91072b31a910 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -279,7 +279,7 @@ struct ReverseAD : ExprMutator { } std::vector orig_args; for (const auto& arg : args) { - orig_args.push_back(GetField(VisitExpr(arg), 0)); + orig_args.push_back(GetField(arg, 0)); } Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args); Var orig_var = ll->Push(orig); From f8c4cb94d7a52d2702e0b8a3bd5e949de4d0f564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 24 May 2019 16:43:03 -0700 Subject: [PATCH 046/176] [Relay] Start porting pass to the pass manager (#3191) --- include/tvm/relay/pass.h | 139 +++++++++++++++--------- include/tvm/relay/transform.h | 120 +++++++++++++++++++- src/relay/pass/dead_code.cc | 12 ++ src/relay/pass/device_annotation.cc | 12 ++ src/relay/pass/fold_constant.cc | 12 ++ src/relay/pass/forward_rewrite.cc | 31 ++++++ src/relay/pass/fuse_ops.cc | 14 +++ src/relay/pass/partial_eval.cc | 12 ++ src/relay/pass/pass_manager.cc | 17 ++- src/relay/pass/to_a_normal_form.cc | 12 ++ src/relay/pass/to_graph_normal_form.cc | 12 ++ tests/python/relay/test_pass_manager.py | 4 +- 12 files changed, 328 insertions(+), 69 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index c84e3f952de4..67cc5df82407 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, */ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); -/*! \brief Compare two expressions for structural equivalence. +/*! + * \brief Compare two expressions for structural equivalence. * * This comparison operator respects scoping and compares * expressions without regard to variable choice. @@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); */ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); -/*! \brief Compare two types for structural equivalence. +/*! + * \brief Compare two types for structural equivalence. * * This comparison operator respects scoping and compares * expressions without regard to variable choice. @@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); */ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); -/*! \brief Add abstraction over a function +/*! + * \brief Add abstraction over a function * * For example: `square` is transformed to * `fun x -> square x`. @@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); */ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); -/*! \brief Check that each Var is only bound once. +/*! + * \brief Check that each Var is only bound once. * * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * @@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); */ TVM_DLL bool WellFormed(const Expr& expr); -/*! \brief Get all bound variables from expression expr. +/*! + * \brief Get all bound variables from expression expr. * * Bound variables are all variables that are declared in the expr. * They only have meaning inside that expr, and can only be used in it. @@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr); */ TVM_DLL tvm::Array BoundVars(const Expr& expr); -/*! \brief Get all bound variables from pattern pat. +/*! + * \brief Get all bound variables from pattern pat. * * Bound variables are all variables that got bound by the pat. * They only have meaning inside that expr, and can only be used in it. @@ -170,7 +177,8 @@ TVM_DLL tvm::Array BoundVars(const Expr& expr); */ TVM_DLL tvm::Array BoundVars(const Pattern& pat); -/*! \brief Get free type parameters from expression expr. +/*! + * \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a * let or a function parameter in the context. @@ -181,7 +189,8 @@ TVM_DLL tvm::Array BoundVars(const Pattern& pat); */ TVM_DLL tvm::Array FreeVars(const Expr& expr); -/*! \brief Get all variables from expression expr. +/*! + * \brief Get all variables from expression expr. * * \param expr the expression. * @@ -189,7 +198,8 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); */ TVM_DLL tvm::Array AllVars(const Expr& expr); -/*! \brief Get free TypeVars from expression expr. +/*! + * \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function * type in the context. @@ -201,7 +211,8 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); */ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get free TypeVars from type t. +/*! + * \brief Get free TypeVars from type t. * * Free type parameters are type parameters that are not bound by a function * type in the context. @@ -213,7 +224,8 @@ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); -/*! \brief Get all bound type variables from expression expr. +/*! + * \brief Get all bound type variables from expression expr. * * Bound variables are all type variables that are declared in the expr. * They only have meaning inside that expr, and can only be used in it. @@ -225,7 +237,8 @@ TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); */ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get all bound type variables from type t. +/*! + * \brief Get all bound type variables from type t. * * Bound variables are all type variables that are declared in the type. * They only have meaning inside that type, and can only be used in it. @@ -237,7 +250,8 @@ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); -/*! \brief Get all type variables in expression expr. +/*! + * \brief Get all type variables in expression expr. * * \param expr the expression. * \param mod the module. @@ -246,7 +260,8 @@ TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get all type variables in type t. +/*! + * \brief Get all type variables in type t. * * \param t the type. * \param mod the module. @@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e); /*! * \brief Fold constant expressions. + * * \param expr the expression to be optimized. + * * \return The optimized expression. */ TVM_DLL Expr FoldConstant(const Expr& expr); /*! * \brief Fuse operations into expr into seperate functions. + * * \param expr The expression. * \param fuse_opt_level Optimization level. * \param mod the module. + * * \return The optimized expression. */ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * * \param expr The expression. * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite * rule function. @@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * * \param expr The expression. * \param rewrite_func The rewrite func that will apply to all operators. * \param fcontext Additional callback to provide context argument for each call node. * \param fmulti_ref_trigger Transformation function to be called when * an Expr consumed by multiple callers. + * * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); /*! * \brief Rewrite the annotated program. + * * \param expr The expression. * \param fallback_device The fallback device which is the default device for * operators without annotation. + * * \return The updated program. */ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); /*! * \brief Collect the device mapping information of each expression. + * * \param expr The expression. + * * \return The device mapping. */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); -/*! \brief A hashing structure in the style of std::hash. */ -struct StructuralHash { - /*! \brief Hash a Relay type. - * - * Implements structural hashing of a Relay type. - * - * \param type the type to hash. - * - * \return the hash value. - */ - size_t operator()(const Type& type) const; - - /*! \brief Hash a Relay expression. - * - * Implements structural hashing of a Relay expression. - * - * \param expr the expression to hash. - * - * \return the hash value. - */ - size_t operator()(const Expr& expr) const; -}; - -/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). +/*! + * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * * It will turn an expression that is in a graph form (with sharing implicit), * to an expression with explicit sharing (A-Normal Form). * * The scope of the root expression is the global scope. - + * * The scope of any non root expression is the least common ancestor of all it's scope. * * Values are ordered by post-DFS order in each scope. * - * \param e the expression to observably share - * + * \param e the expression to observably share. * \param mod The module used for referencing global functions, can be * None. * - * \return expression in A-Normal Form + * \return expression in A-Normal Form. */ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); -/*! \brief Remove let binding and directly share via pointer instead. +/*! + * \brief Remove let binding and directly share via pointer instead. * * It will remove all let binding, * and turn all of the variable bound by let into direct pointer reference. @@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); */ TVM_DLL Expr ToGraphNormalForm(const Expr& e); -/*! \brief Aggressive constant propagation/constant folding/inlining. +/*! + * \brief Aggressive constant propagation/constant folding/inlining. + * * It will do as much computation in compile time as possible. * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. + * + * \param e the expression, + * + * \return the optimized expression. */ -Expr PartialEval(const Expr& e); +TVM_DLL Expr PartialEval(const Expr& e); + +/*! \brief A hashing structure in the style of std::hash. */ +struct StructuralHash { + /*! \brief Hash a Relay type. + * + * Implements structural hashing of a Relay type. + * + * \param type the type to hash. + * + * \return the hash value. + */ + size_t operator()(const Type& type) const; + + /*! \brief Hash a Relay expression. + * + * Implements structural hashing of a Relay expression. + * + * \param expr the expression to hash. + * + * \return the hash value. + */ + size_t operator()(const Expr& expr) const; +}; namespace vm { -/*! \brief Compile a module, and construct the virtual machine. +/*! + * \brief Compile a module, and construct the virtual machine. * * \param mod The module to compile. + * * \return The constructed virtual machine. */ runtime::vm::VirtualMachine CompileModule(const Module& mod); diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 5123f3a3dcf3..4d6921a6b860 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -61,6 +61,7 @@ #include #include #include +#include #include #include #include @@ -198,7 +199,7 @@ class Pass; */ class PassNode : public RelayNode { public: - /* + /*! * \brief Get the pass information/meta data. */ virtual PassInfo Info() const = 0; @@ -300,11 +301,118 @@ Pass CreateModulePass( * * \return The created function pass. */ -Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< + Function(Function, Module, PassContext)>& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/*! \brief Remove expressions which does not effect the program result. + * + * It will remove let bindings which are not referenced, + * and inline let bindings that are only used once. + * + * For example, this pass should turn `let a = 1 in 2` into `2`, + * as the value of the expression does not depend on a. + * + * As another example, `let a = 1 in a` will be optimized into 1. + * + * \return the pass. + */ +TVM_DLL Pass DeadCodeElimination(); + +/*! + * \brief Fold constant expressions. + * + * \return The pass. + */ +TVM_DLL Pass FoldConstant(); + +/*! + * \brief Fuse operations into expr into seperate functions. + * + * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context. + * + * \return The pass. + */ +TVM_DLL Pass FuseOps(int fuse_opt_level = -1); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * + * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite + * rule function. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The pass. + */ +TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function + fmulti_ref_trigger = nullptr); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * + * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The pass. + */ +TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +/*! + * \brief Rewrite the annotated program. + * + * \param fallback_device The fallback device which is the default device for + * operators without annotation. + * + * \return The pass. + */ +TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); + +/*! + * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). + * + * It will turn an expression that is in a graph form (with sharing implicit), + * to an expression with explicit sharing (A-Normal Form). + * + * The scope of the root expression is the global scope. + * + * The scope of any non root expression is the least common ancestor of all it's scope. + * + * Values are ordered by post-DFS order in each scope. + * + * \return The pass. + */ +TVM_DLL Pass ToANormalForm(); + +/*! + * \brief Remove let binding and directly share via pointer instead. + * + * It will remove all let binding, + * and turn all of the variable bound by let into direct pointer reference. + * + * \return the expression in graph normal form. + */ +TVM_DLL Pass ToGraphNormalForm(); + +/*! + * \brief Aggressive constant propagation/constant folding/inlining. + * + * It will do as much computation in compile time as possible. + * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). + * As a side effect, code size will explode. + * + * \return the optimized expression. + */ +TVM_DLL Pass PartialEval(); } // namespace transform } // namespace relay diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 533c21429995..dd1ed6240cab 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") .set_body_typed(DeadCodeElimination); +namespace transform { + +Pass DeadCodeElimination() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(DeadCodeElimination(f)); + }; + return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 8807f6dd4cf4..fa656dbf489e 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); +namespace transform { + +Pass RewriteAnnotatedOps(int fallback_device) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(RewriteAnnotatedOps(f, fallback_device)); + }; + return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index c085d80d06e2..286392ab5d3f 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) { TVM_REGISTER_API("relay._ir_pass.FoldConstant") .set_body_typed(FoldConstant); +namespace transform { + +Pass FoldConstant() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(FoldConstant(f)); + }; + return CreateFunctionPass(pass_func, 1, "fold_constant", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 88a2d669da9f..2a3aa1612418 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr, return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } +namespace transform { + +using std::function; + +Pass ForwardRewrite(const std::string& rewrite_map_attr_name, + function fcontext, + function fmulti_ref_trigger) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_map_attr_name, + fcontext, + fmulti_ref_trigger)); + }; + return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); +} + +Pass ForwardRewrite(const FForwardRewrite& rewrite_func, + function fcontext, + function fmulti_ref_trigger) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_func, + fcontext, + fmulti_ref_trigger)); + }; + return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); +} + +} // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index d0d0cab22432..9277689075c2 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { TVM_REGISTER_API("relay._ir_pass.FuseOps") .set_body_typed(FuseOps); + +namespace transform { + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + return Downcast(FuseOps(f, opt_level, m)); + }; + return CreateFunctionPass(pass_func, 1, "fuse_ops", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index ad861743dfd5..3f42c6fce4b2 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate") *ret = PartialEval(args[0]); }); +namespace transform { + +Pass PartialEval() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(PartialEval(f)); + }; + return CreateFunctionPass(pass_func, 1, "partial_eval", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 4bcc0bb39cc4..ea4c976b7db5 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode { * `pass_func` and let it run on a given module. The same `pass_func` will * then be applied on each function in the module. */ - runtime::TypedPackedFunc pass_func; + runtime::TypedPackedFunc pass_func; FunctionPassNode() = default; @@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode { PassInfo Info() const { return pass_info; } TVM_DLL static FunctionPass make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info); static constexpr const char* _type_key = "relay.FunctionPass"; @@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod, } FunctionPass FunctionPassNode::make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_node(); n->pass_func = std::move(pass_func); @@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod, // Execute the pass function and return a new module. for (const auto& it : mod->functions) { - auto updated_func = - SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx); + auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); new_mod->Add(it.first, updated_func); } @@ -501,7 +500,7 @@ Pass CreateModulePass( } Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, const tvm::Array& required) { @@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) tvm::IRPrinter* p) { const PassInfoNode* seq_pn = node->Info().operator->(); p->stream << "Run Sequential pass: " << seq_pn->name - << " at the optimization level. " << seq_pn->opt_level; + << " at the optimization level " << seq_pn->opt_level << ". "; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { const PassNode* pn = it.operator->(); diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 913f8de05d7b..f9d47f78a6d2 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) { TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") .set_body_typed(static_cast(ToANormalForm)); +namespace transform { + +Pass ToANormalForm() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ToANormalForm(f, m)); + }; + return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 490a80f308ce..50ebb702e4b2 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") .set_body_typed(ToGraphNormalForm); +namespace transform { + +Pass ToGraphNormalForm() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ToGraphNormalForm(f)); + }; + return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); +} + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index db346e7f712f..2703e5ce1679 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -204,7 +204,7 @@ def test_function_pass(): pass_ctx = None @_transform.function_pass(opt_level=opt_level, name=pass_name) - def transform(expr, ctx): + def transform(expr, mod, ctx): return opt_tester.transform(expr, ctx) def get_ref_log(): @@ -303,7 +303,7 @@ def mod_transform(expr, ctx): # Register a function pass. @_transform.function_pass(opt_level=1) - def func_transform(expr, ctx): + def func_transform(expr, mod, ctx): return opt_tester.transform(expr, ctx) function_pass = func_transform From 4b8bae30c132093418cdf230532fa69027ffd5b1 Mon Sep 17 00:00:00 2001 From: Gemfield Date: Sat, 25 May 2019 08:48:58 +0800 Subject: [PATCH 047/176] Fixed a typo (#3218) * Fixed a typo * Remove outdated url link. --- apps/android_rpc/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/android_rpc/README.md b/apps/android_rpc/README.md index 36ace85405c0..38725917f424 100644 --- a/apps/android_rpc/README.md +++ b/apps/android_rpc/README.md @@ -141,7 +141,7 @@ export TVM_NDK_CC=/opt/android-toolchain-arm64/bin/aarch64-linux-android-g++ python android_rpc_test.py ``` -This will compile TVM IR to shared libraries (CPU, OpenCL and Vulkan) and run vector addition on your Android device. To verify compiled TVM IR shared libraries on OpenCL target set [`'test_opencl = True'`](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py#L25) and on Vulkan target set [`'test_vulkan = False'`](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py#L27) in [tests/android_rpc_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py), by default on CPU target will execute. +This will compile TVM IR to shared libraries (CPU, OpenCL and Vulkan) and run vector addition on your Android device. To verify compiled TVM IR shared libraries on OpenCL target set `'test_opencl = True'` and on Vulkan target set `'test_vulkan = True'` in [tests/android_rpc_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py), by default on CPU target will execute. On my test device, it gives following results. ```bash From a6fc910091c3c1b38431c5ca8c82eb6ac3fb09a2 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sat, 25 May 2019 17:40:02 -0700 Subject: [PATCH 048/176] [Relay][Frontend] Add Crop op converter (#3241) * Add Crop op converter * lint * x --- nnvm/python/nnvm/frontend/mxnet.py | 2 +- python/tvm/relay/frontend/mxnet.py | 32 +++++++++++++++++++-- tests/python/frontend/mxnet/test_forward.py | 26 +++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 77671225aa3e..6f6bfc87ea8a 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -269,7 +269,7 @@ def _crop_like(inputs, attrs): raise tvm.error.OpAttributeUnimplemented( 'Center crop is not supported in operator crop_like.') if len(inputs) < 2: - raise RuntimeError("Only support crop_like pattern.") + raise tvm.error.OpAttributeUnimplemented("Only support crop_like pattern.") new_attrs["axis"] = [2, 3] return get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1a4d52f5b679..0bc7923648ff 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -149,7 +149,7 @@ def _mx_conv2d_transpose(inputs, attrs): new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout new_attrs["kernel_layout"] = kernel_layout - use_bias = not attrs.get_bool("no_bias", False) + use_bias = not attrs.get_bool("no_bias", True) res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs) if use_bias: @@ -277,6 +277,28 @@ def _mx_slice_axis(inputs, attrs): return _op.strided_slice(inputs[0], begin, end) +def _mx_crop_like(inputs, attrs): + if len(inputs) < 2: + raise tvm.error.OpAttributeUnimplemented( + "Only support crop_like pattern for operator Crop.") + if attrs.get_bool("center_crop", False): + raise tvm.error.OpAttributeUnimplemented( + "Center crop is not supported in operator Crop.") + if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0): + raise tvm.error.OpAttributeUnimplemented( + "Doesn't support h_w in operator Crop.") + offset = attrs.get_int_tuple("offset", (0, 0)) + new_attrs = {} + if offset == (0, 0): + new_attrs["axes"] = (2, 3) + return _op.slice_like(*inputs, **new_attrs) + like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape + new_attrs['begin'] = [0, 0, offset[0], offset[1]] + new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], + offset[1]+like_shape[3]] + return _op.strided_slice(inputs[0], **new_attrs) + + def _mx_split(inputs, attrs): axis = attrs.get_int("axis", 1) new_attrs = {} @@ -300,6 +322,10 @@ def _mx_softmax_output(inputs, attrs): return _op.nn.softmax(inputs[0]) +def _mx_linear_regression_output(inputs, _): + return inputs[0] + + def _mx_concat(inputs, attrs): axis = attrs.get_int("dim", 1) return _op.concatenate(tuple(inputs), axis=axis) @@ -890,6 +916,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "argsort" : _mx_argsort, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, + "LinearRegressionOutput" : _mx_linear_regression_output, "smooth_l1" : _mx_smooth_l1, # vision "_contrib_BilinearResize2D" : _mx_resize, @@ -905,11 +932,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): # NLP "RNN" : _mx_rnn_layer, "_rnn_param_concat" : _mx_rnn_param_concat, + # Depricated: + "Crop" : _mx_crop_like, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # # "broadcast_to", - # "Crop" : _crop_like, } # set identity list diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index e75e60da5ce4..50a25a9aff61 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -583,6 +583,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1): verify(mode, 64, 10, 64, 2) verify(mode, 64, 10, 32, 2) +def test_forward_Crop(): + def verify(xshape, yshape, offset=None): + x_data = np.random.uniform(size=xshape).astype("float32") + y_data = np.random.uniform(size=yshape).astype("float32") + if offset is None: + mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y")) + ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data)) + else: + mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset) + ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + if offset is None or offset == (0, 0): + op_res = intrp.evaluate(new_sym)(x_data, y_data) + else: + op_res = intrp.evaluate(new_sym)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((1, 3, 40, 40), (1, 3, 20, 20)) + verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0)) + verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10)) + verify((5, 32, 40, 40), (5, 32, 25, 25)) + verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5)) + if __name__ == '__main__': test_forward_mlp() @@ -624,3 +649,4 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1): test_forward_gather_nd() test_forward_bilinear_resize() test_forward_rnn_layer() + test_forward_Crop() From cbc719b20470357d24219828a9df55cf50c4091a Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Mon, 27 May 2019 19:33:13 +0300 Subject: [PATCH 049/176] [ARITH] Improve div/mod in rewrite simplifier (#3149) * [ARITH] Improve div/mod in rewrite simplifier * Fix lint error * Fuller file name in src/arithmetic/modular_set.h Co-Authored-By: Wei Chen * Generalize some rules * Replace gcd factoring with specialized rules * Mark rules that don't work for non-truncated division * More tests --- src/arithmetic/modular_set.cc | 2 + src/arithmetic/rewrite_simplify.cc | 102 +++++++++++++++--- .../unittest/test_arith_rewrite_simplify.py | 69 ++++++++++++ 3 files changed, 161 insertions(+), 12 deletions(-) diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 57e82943b84c..b3e943fc7631 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include "pattern_match.h" namespace tvm { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 0de2a2535ae7..00198d9b140a 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) { return kLT; } } - if (val == 0) { - ModularSet dmod = parent_->modular_set(diff); - if (dmod->base != 0) { - return kNE; - } - } ConstIntBound dbound = parent_->const_int_bound(diff); if (dbound->min_value > val) { return kGT; @@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) { if (dbound->max_value <= val) { return kLE; } + if (val == 0) { + ModularSet dmod = parent_->modular_set(diff); + if (dmod->base != 0) { + return kNE; + } + } return kUnknown; } @@ -284,11 +284,39 @@ Mutate_(const Sub* op, const Expr& self) { CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0)); // modular-div simplification - // Always pre-condition on positive integer domain + // Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, - CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + c1.Eval()->value != 0); TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), - CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y, + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1), + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y, + c1.Eval()->value != 0); + + TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && @@ -348,6 +376,7 @@ Mutate_(const Mul* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); TVM_TRY_RECURSIVE_REWRITE_IF( (x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); @@ -396,6 +425,16 @@ Mutate_(const Div* op, const Expr& self) { // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. + // TryConstFold doesn't work for negative cases because it is also used by legacy + // parts of tvm which still assume euclidean div. In this simplifier we assume that the division + // is truncated, so perform const folding again. + // NOTE: trunc div required + if ((c1 / c2).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + return make_const(op->type, c1val / c2val); + } + // while it is always true for trunc div // restrict to common case(positive div) TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), @@ -608,6 +647,12 @@ Mutate_(const Mod* op, const Expr& self) { CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + // canonicalization: x % c == x % (-c) for truncated division + // NOTE: trunc div required + TVM_TRY_RECURSIVE_REWRITE_IF(x % c1, + x % PConst(make_const(op->type, -c1.Eval()->value)), + c1.Eval()->value < 0); + // try modular analysis if ((x % c1).Match(ret)) { ModularSet mod = parent_->modular_set(x.Eval()); @@ -1025,20 +1070,53 @@ Mutate_(const LT* op, const Expr& self) { TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); - // require c1 > 0 to work for any div mode TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, + c1.Eval()->value <= 0 && + c2.Eval()->value > 0); + // NOTE: trunc div required (euclidean is ok too, floored is not) + TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x, c1.Eval()->value > 0 && + c2.Eval()->value < 0); + // NOTE: trunc div required (floored is ok too, euclidean is not) + TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x, + c1.Eval()->value <= 0 && + c2.Eval()->value < 0); + + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x, + c1.Eval()->value < 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); + // NOTE: trunc div required (floored is ok too, euclidean is not) + TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1, + c1.Eval()->value < 0 && + c2.Eval()->value < 0); + // NOTE: trunc div required (euclidean is ok too, floored is not) + TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2, + c1.Eval()->value >= 0 && + c2.Eval()->value < 0); + + TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + c1.Eval()->value > 0 && + c2.Eval()->value > 0); + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1, + c1.Eval()->value > 0 && + c2.Eval()->value <= 0); + TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x, + c1.Eval()->value < 0 && + c2.Eval()->value > 0); // division related simplificationx // invariance for any div mod: x - (x / c1) * c1 == x % c1 diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index be961a5c6543..1b03253c9a0f 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -227,6 +227,25 @@ def test_sub_index_simplify(): ck.verify(x - (x / 3) * 3, x % 3) ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3) + ck.verify(y - (y / (-5)) * (-5), y % 5) + ck.verify((y / 3) * 3 - y, 0 - y % 3) + ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6) + ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5) + ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z) + ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5) + ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z) + ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z) + + ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3) + ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2)) + ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) + ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5) + ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2) + ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2) + ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) + ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5) + ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2) + ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2) def test_mul_index_simplify(): ck = RewriteChecker() @@ -292,6 +311,11 @@ def test_mod_index_simplify(): ck.verify((x + 10) % 2, x % 2) ck.verify((x + y * 10) % 2, x % 2) ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1) + ck.verify(x * 10 % -2, 0) + ck.verify((x * 10 + y) % -2, y % 2) + ck.verify((x + 10) % -2, x % 2) + ck.verify((x + y * 10) % -2, x % 2) + ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1) def test_min_index_simplify(): @@ -449,6 +473,50 @@ def test_cmp_simplify(): ck.verify(x / 2 < 3, x < 6) ck.verify(x * 4 <= 2, x <= 0) ck.verify(3 < x / 2, tvm.expr.LT(7, x)) + ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x)) + ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x)) + ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0)) + ck.verify(2 * x <= 0, x <= 0) + + ck.verify(x * 2 >= 3, tvm.expr.LE(2, x)) + ck.verify(x * 2 >= 2, tvm.expr.LE(1, x)) + ck.verify(x * 2 >= 1, tvm.expr.LE(1, x)) + ck.verify(x * 2 >= 0, tvm.expr.LE(0, x)) + ck.verify(x * 2 >= -1, tvm.expr.LE(0, x)) + ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x)) + ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x)) + + ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1)) + ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1)) + ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0)) + ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0)) + ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1)) + ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1)) + ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2)) + + ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2)) + ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1)) + ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1)) + ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 0)) + ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0)) + ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1)) + ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1)) + + ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x)) + ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x)) + ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x)) + ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x)) + ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x)) + ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x)) + ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x)) + + ck.verify(x / 2 >= 1, tvm.expr.LE(2, x)) + ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x)) + ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x)) + + ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3)) + ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1)) + ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2)) ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4)) ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0)) @@ -480,6 +548,7 @@ def test_cmp_simplify(): ck.verify(x*y <= 0, tvm.const(1, "bool")) ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool")) ck.verify(y*y >= 0, tvm.const(1, "bool")) + ck.verify(x*6 <= -3, tvm.const(0, "bool")) def test_logical_simplify(): From f3b4c80e3f8574f46f74a929bf6628ef38b96c74 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Tue, 28 May 2019 06:24:54 +0800 Subject: [PATCH 050/176] [Doc][Relay] Add VM doc (#3188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Doc][Relay] Add VM doc * Add Apache header * Apply suggestions from code review Co-Authored-By: Steven S. Lyubomirsky Co-Authored-By: 雾雨魔理沙 Co-Authored-By: Logan Weber <36520469+weberlo@users.noreply.github.com> Co-Authored-By: Zhi <5145158+zhiics@users.noreply.github.com> * Junru's comment * More fix * More fix * More fix * last fix * Apply suggestions from code review Co-Authored-By: 雾雨魔理沙 * Apply suggestions from code review Co-Authored-By: Logan Weber <36520469+weberlo@users.noreply.github.com> * Add code links * Remove unused bp * Update docs/dev/virtual_machine.rst Co-Authored-By: Logan Weber <36520469+weberlo@users.noreply.github.com> * Explain TODO * Yong's comment Co-Authored-By: Yong Wu <55wuyong@163.com> * Comment --- docs/dev/virtual_machine.rst | 314 +++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 docs/dev/virtual_machine.rst diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst new file mode 100644 index 000000000000..a59620a0a861 --- /dev/null +++ b/docs/dev/virtual_machine.rst @@ -0,0 +1,314 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Putting the VM in TVM: The Relay Virtual Machine +================================================ + +Relay, a new program representation, has enabled the representation and optimization of +a great breadth of machine learning programs. +Unfortunately, by supporting a more expressive set of programs, we have +introduced several new execution challenges. + +Relay's interpreter can execute the full language but has notable limitations +that make it unsuited for production deployments. It is structured as an inefficient +interpreter that performs AST traversal to execute the program. This approach is conceptually +simple but inefficient, as the AST traversal heavily relies on indirection. + +There are further challenges in compiling dynamic code, such as dynamic scheduling and allocation, +fully dynamic tensor shapes, and control flow. The interpreter offers simple solutions +for these, but none is sufficiently compelling or optimized. + +The second execution mechanism is the existing graph runtime. In order to target Relay +programs to this, we compile a small subset of them to the old graph format and execute +them on the runtime. Graph runtime provides a fast execution experience but only for a very limited +subset of Relay programs. + +An alternative but not-standard approach is Relay's ahead-of-time compiler, +which compiles a Relay program into a shared library containing an ahead- +of-time implementation. The ahead-of-time compiler provides compelling performance +but is difficult to extend and instrument, which can only be done by modifying the +code generation and optimization mechanisms. + +The Relay virtual machine is intended to be a framework that balances these competing +approaches, providing a dynamic execution environment which can be extended, instrumented, +and integrated with other approaches like ahead-of-time compilation via a flexible extension +mechanism. + +The virtual machine is designed to strike a balance between performance and flexibility +when deploying and executing Relay programs, without giving up the benefits of TVM. + +Virtual machine (VM) design is a well-studied area in programming languages and systems, +and there have been various virtual machine designs for both full-fledged +and embedded programing languages. +Previous language VM designs have been heavily tailored to the execution profile of traditional programs. +Traditional programs manipulate small scalar values and consist of a large number of low-level instructions. +The sheer quantity of instructions requires instruction execution and dispatch to be extremely efficient. +In the context of machine learning we manipulate primarily tensor values, using a (relatively) +low number of high level instructions. ML programs' cost centers are expensive operator invocations, +such as GEMM or convolution, over a large input. Due to the execution profile exhibited by ML programs, +micro-optimizations present in scalar VMs are dramatically less important. + +TVM has provided strong support for vision models, +but we want to grow to support a wider variety of models. +The graph runtime is able to utilize the fully static nature of the input graphs to perform +aggressive optimization such as fully static allocation, and optimal memory reuse. +When we introduce models which make use of control flow, recursion, dynamic shapes, and dynamic +allocation, we must change how execution works. A virtual machine for Relay is a natural choice. + +The rest of this document provides a high-level overview of the Relay +virtual machine design and its instruction set. + +Design +------ + +The VM's design is focused on simplicity without sacrificing performance. +In order to accomplish this we have focused on designing a tensor VM rather than a scalar VM. + +In the tensor VM setting, we optimize for cheap “allocation” of objects (by trying to avoid real allocation), +reuse of static fragments, and the ability to do dynamic shape (i.e jagged tensors). + +Instruction Set +~~~~~~~~~~~~~~~ + +The choices of an instruction set and instruction representation are the most critical design decisions for a VM. +The current representation of the instructions is a tagged union containing the op-code and the data payload. An important design decision is the level of abstraction of the instructions (RISC vs. CISC) and how they take their data (fixed-width instruction encoding vs. variable-length encoding). The current version is closer to CISC, with complex instructions like AllocTensor, and is variable-length due to the inclusion of the shape as part of the instruction. The current instruction set is very high-level and corresponds roughly to high-level operations in Relay. + +Ret +^^^ +**Arguments**: +:: + RegName dst + RegName result + +Returns the object in register `result` to caller's register `dst`. + +InvokePacked +^^^^^^^^^^^^ +**Arguments**: +:: + size_t packed_index + size_t arity + size_t output_size + RegName* packed_args + +Invoke the packed function denoted by `packed_index`. The `arity` +and `output_size` are used to inform the VM how many inputs and +outputs to expect. `packed_args` stores the list of argument registers. + +AllocTensor +^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + RegName shape_register + size_t ndim + DLDataType dtype + +Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result +is saved to register `dst`. + +AllocDatatype +^^^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t tag + size_t num_fields + RegName* datatype_fields + +Allocate a data type with the tag `tag` using the `num_fields` entries +from registers `datatype_fields`. The result is saved to register `dst`. + +AllocClosure +^^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t clo_index + size_t num_freevar + RegName* free_vars; + +Allocate a closure with the VMFunction at `clo_index` as +its code, and the `num_freevar` entries from registers in +`free_vars`. The result is saved to register `dst`. + +GetField +^^^^^^^^ +**Arguments**: +:: + RegName dst + RegName object + size_t field_index + +Get the field value with index `field_index` from `object`. And saves the result to register `dst`. + +If +^^ +**Arguments**: +:: + RegName if_cond + size_t true_offset + size_t false_offset + +Check if the object at register `if_cond` is `true` or `false`. +If `true`, relative jump by `true_offset`, else relative +jump by `false_offset`. + +Goto +^^^^ +**Arguments**: +:: + size_t pc_offset + +Relative unconditional jump by `pc_offset`. + +Invoke +^^^^^^ +**Arguments**: +:: + size_t func_index + +Invoke function at `func_index`, consumes the number of arguments contained in the VMFunction's +arity field. + +InvokeClosure +^^^^^^^^^^^^^ +**Arguments**: +:: + RegName closure + size_t closure_args_num + RegName* closure_args + +Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction. + +LoadConst +^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t const_index + +Load the constant at `const_index` from the constant pool. The result is saved to register `dst`. + +Object Representation +~~~~~~~~~~~~~~~~~~~~~ +We use a simple object representation that uses shared pointers and tagging. +There is a huge space of possible object representations trade-offs, but we +believe micro-optimizing this code has little to no effect on the end-to-end performance. + +:: + + struct ObjectCell { + ObjectTag tag; + ... + }; + + struct Object { + std::shared_ptr ptr; + ... + } + +See `include/tvm/runtime/vm.h` for more details. + +Currently, we support 3 types of objects: tensors, data types, and closures. + +:: + + VMObject VMTensor(const tvm::runtime::NDArray& data); + VMObject VMDatatype(size_t tag, const std::vector& fields); + VMObject VMClosure(size_t func_index, std::vector free_vars); + + +Stack and State +~~~~~~~~~~~~~~~ + +The Relay VM maintains a stack frame, which contains information about how to resume the +previous call. Registers are allocated in a continuous space (virtual register file) for each function. + +We keep track of a set of Relay functions we have called, a pointer into its bytecode, an offset into the byte code (known as the program counter). + +:: + + struct VirtualMachine { + ... + std::vector frames; + ... + // Current function. + size_t func_index; + // Pointer into the current function's instructions. + const Instruction* code; + // Current program counter relative to the code pointer. + size_t pc; + ... + }; + + +Dispatch Loop +~~~~~~~~~~~~~ +A critical piece of a VM is the dispatch loop. The dispatch loop usually dominates the execution time of a +virtual machine, but we have experimentally found this not to be the case for Relay. We have just implemented +a simple `switch`/`goto` dispatch loop which dispatches based on instruction op code. + +This loop is implemented by `VirtualMachine::Run()`. + +VM Compiler +~~~~~~~~~~~ + +An important part of this infrastructure is a compiler from Relay's full IR into a sequence of bytecode. +The VM compiler transforms a `tvm::relay::Module` into a `tvm::relay::vm::VirtualMachine`. The virtual +machine contains a set of compiled functions, the compiled functions are contained in `tvm::relay::vm::Function`. The functions contain metadata about the the function as well as its compiled bytecode. For full definitions of the data structures see `vm.h`. + +Optimizations +~~~~~~~~~~~~~ + +There are quite a few optimizations required by the VM compiler. + +We have implemented them in the old pass style, but plan to port them to +the new pass manager (#2546) before merging. + +Optimizations marked with `TODO` are not implemented yet. + +- A-Normal Form +- Lambda Lift (see `src/relay/vm/lambda_lift.cc`) +- Inline Primitives (see `src/relay/vm/inline_primitives.cc`) +- Inliner (see `src/relay/pass/inliner.cc`) +- Constant Pool Layout (see `src/relay/backend/vm/compiler.cc`) +- ADT Tag Allocation (see `src/relay/backend/vm/compiler.cc`) +- Tail Call Optimization (TODO) +- Liveness Analysis (TODO) + +Serialization +~~~~~~~~~~~~~ + +A final and yet-to-be-implemented part of the VM design is serialization. The accompanying PR will introduce both the bytecode and its serialization, as well as VM-level serialization. The design premise is that a VM can be efficiently stored to disk and resumed at a later time. This would also allow us to efficiently schedule many models on to a single machine in order to obtain good utilization. + +Unresolved Questions +~~~~~~~~~~~~~~~~~~~~ + +How do we handle dynamic shapes? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +TODO + +How can we modify the VM to support JIT compilation of certain code paths? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +In the code generation space there are still many tradeoffs to be analyzed and the VM is designed +to be very flexible so we can modify it for future experiments. + +How do we support heterogenous execution? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Heterogenous execution should work out of the box assuming we have annotated the appropriate device copies. +In order to do this properly we need to run the device annotation and copying passes. From 8bd604689d16b877a4da4c07c2ca18516f55712e Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Mon, 27 May 2019 21:29:55 -0700 Subject: [PATCH 051/176] [VTA][TSIM] Use Module instead of RawModule for testbench by creating an empty bundle for the IO (#3242) * use Module instead of RawModule for testbench by creating an empty bundle for the IO * change default back to verilog --- .../chisel/src/test/scala/dut/TestAccel.scala | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/vta/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala b/vta/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala index 45f81d50a50b..2c02ff36a631 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala @@ -20,7 +20,6 @@ package test import chisel3._ -import chisel3.experimental.{RawModule, withClockAndReset} import vta.dpi._ import accel._ @@ -29,21 +28,19 @@ import accel._ * Instantiate Host and Memory DPI modules. * */ -class VTASimShell extends RawModule { +class VTASimShell extends Module { val io = IO(new Bundle { - val clock = Input(Clock()) - val reset = Input(Bool()) val host = new VTAHostDPIMaster val mem = new VTAMemDPIClient }) val host = Module(new VTAHostDPI) val mem = Module(new VTAMemDPI) - mem.io.reset := io.reset - mem.io.clock := io.clock - host.io.reset := io.reset - host.io.clock := io.clock - io.mem <> mem.io.dpi + mem.io.dpi <> io.mem + mem.io.reset := reset + mem.io.clock := clock io.host <> host.io.dpi + host.io.reset := reset + host.io.clock := clock } /** Test accelerator. @@ -51,15 +48,10 @@ class VTASimShell extends RawModule { * Instantiate and connect the simulation-shell and the accelerator. * */ -class TestAccel extends RawModule { - val clock = IO(Input(Clock())) - val reset = IO(Input(Bool())) - +class TestAccel extends Module { + val io = IO(new Bundle {}) val sim_shell = Module(new VTASimShell) - val vta_accel = withClockAndReset(clock, reset) { Module(new Accel) } - - sim_shell.io.clock := clock - sim_shell.io.reset := reset + val vta_accel = Module(new Accel) vta_accel.io.host <> sim_shell.io.host sim_shell.io.mem <> vta_accel.io.mem } From f92888727e088a8d88c87529704cf33a2124e260 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 29 May 2019 02:02:48 +0800 Subject: [PATCH 052/176] Move CombineParallelConv2D to opt level 4 (#3248) --- src/relay/backend/build_module.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 3b1491072d25..57dc256ef6b7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -65,7 +65,7 @@ const std::unordered_map OptPassLevel::_data = { {"SimplifyInference", 0}, {"OpFusion", 1}, {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, + {"CombineParallelConv2D", 4}, {"FoldScaleAxis", 3}, {"AlterOpLayout", 3}, {"CanonicalizeOps", 3}, From d89257a0b8750886be0c8efa1ae063ed3950389c Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 28 May 2019 13:35:09 -0700 Subject: [PATCH 053/176] kCustomBegin overlapped with kExtEnd; incr by 1 (#3250) This was a typo in the original custom datatypes PR. --- include/tvm/runtime/c_runtime_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index ee3542f90255..fd1b877f6d4c 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -115,7 +115,7 @@ typedef enum { kExtReserveEnd = 64U, kExtEnd = 128U, // The rest of the space is used for custom, user-supplied datatypes - kCustomBegin = 128U, + kCustomBegin = 129U, } TVMTypeCode; /*! From 87538e4cd812c118c351ba063910989e9e06b523 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 28 May 2019 13:35:25 -0700 Subject: [PATCH 054/176] Typo: Tensorflow --> TensorFlow (#3249) --- 3rdparty/bfloat16/bfloat16.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc index 333b534afc08..1f25be17f72d 100644 --- a/3rdparty/bfloat16/bfloat16.cc +++ b/3rdparty/bfloat16/bfloat16.cc @@ -3,7 +3,7 @@ \file tvm/src/codegen/custom_datatypes/mybfloat16.cc \brief Small bfloat16 library for use in unittests - Code originally from Tensorflow; taken and simplified. Original license: + Code originally from TensorFlow; taken and simplified. Original license: Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From f736980ec84b4abdd2174be3a432a9b7a3fc54e1 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Tue, 28 May 2019 15:20:18 -0700 Subject: [PATCH 055/176] [RUST] Rust DSO module (#2976) --- rust/Cargo.toml | 1 + rust/common/build.rs | 27 ++-- rust/runtime/Cargo.toml | 3 + rust/runtime/src/module/dso.rs | 144 ++++++++++++++++++ rust/runtime/src/module/mod.rs | 56 +++++++ .../src/{module.rs => module/syslib.rs} | 35 +---- rust/runtime/src/threading.rs | 2 +- rust/runtime/tests/test_tvm_dso/Cargo.toml | 26 ++++ rust/runtime/tests/test_tvm_dso/build.rs | 42 +++++ .../tests/test_tvm_dso/src/build_test_lib.py | 40 +++++ rust/runtime/tests/test_tvm_dso/src/main.rs | 42 +++++ tests/scripts/task_rust.sh | 4 + 12 files changed, 379 insertions(+), 43 deletions(-) create mode 100644 rust/runtime/src/module/dso.rs create mode 100644 rust/runtime/src/module/mod.rs rename rust/runtime/src/{module.rs => module/syslib.rs} (62%) create mode 100644 rust/runtime/tests/test_tvm_dso/Cargo.toml create mode 100644 rust/runtime/tests/test_tvm_dso/build.rs create mode 100755 rust/runtime/tests/test_tvm_dso/src/build_test_lib.py create mode 100644 rust/runtime/tests/test_tvm_dso/src/main.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 25466e08bdf9..6e89bae5c6f2 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -20,6 +20,7 @@ members = [ "common", "runtime", "runtime/tests/test_tvm_basic", + "runtime/tests/test_tvm_dso", "runtime/tests/test_nnvm", "frontend", "frontend/tests/basics", diff --git a/rust/common/build.rs b/rust/common/build.rs index 5dac99ec54bb..919e0adc46c8 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -22,23 +22,30 @@ extern crate bindgen; use std::path::PathBuf; fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + tvm_home + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + println!("cargo:rustc-link-search={}/build", tvm_home); } // @see rust-bindgen#550 for `blacklist_type` bindgen::Builder::default() - .header(format!( - "{}/include/tvm/runtime/c_runtime_api.h", - env!("TVM_HOME") - )) - .header(format!( - "{}/include/tvm/runtime/c_backend_api.h", - env!("TVM_HOME") - )) - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) .blacklist_type("max_align_t") .layout_tests(false) .derive_partialeq(true) diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index 8e70565a6c13..5809af0c6c6d 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" + +[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] +libloading = "0.5" diff --git a/rust/runtime/src/module/dso.rs b/rust/runtime/src/module/dso.rs new file mode 100644 index 000000000000..3442fad13bf9 --- /dev/null +++ b/rust/runtime/src/module/dso.rs @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + cell::RefCell, + collections::HashMap, + ffi::CStr, + os::raw::{c_char, c_int, c_void}, + pin::Pin, +}; + +use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; + +use crate::{ + threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch}, + workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace}, + TVMAPISetLastError, +}; + +use super::Module; + +const TVM_MAIN: &'static [u8] = b"__tvm_main__"; +const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx"; + +/// A module backed by a Dynamic Shared Object (dylib). +pub struct DsoModule<'a> { + lib: libloading::Library, + packed_funcs: RefCell>, + _pin: std::marker::PhantomPinned, +} + +macro_rules! init_context_func { + ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => { + unsafe { + $( + let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes()); + if let Ok(fn_ptr) = fn_ptr { + **fn_ptr = $fn; + } + )+ + } + }; +} + +impl<'a> DsoModule<'a> { + pub fn new>(filename: P) -> Result>, failure::Error> { + let lib = libloading::Library::new(filename)?; + + init_context_func!( + lib, + (TVMAPISetLastError, extern "C" fn(*const i8)), + ( + TVMBackendAllocWorkspace, + extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void + ), + ( + TVMBackendFreeWorkspace, + extern "C" fn(c_int, c_int, *mut c_void) -> c_int + ), + ( + TVMBackendParallelLaunch, + extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int + ), + ( + TVMBackendParallelBarrier, + extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv) + ), + ); + + // Pin the module in memory so that `ctx` pointer (below) is stable. + let dso_mod = Box::pin(Self { + lib, + packed_funcs: RefCell::new(HashMap::new()), + _pin: std::marker::PhantomPinned, + }); + + unsafe { + if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) { + **ctx = &dso_mod as *const _ as *const c_void; + } + } + + Ok(dso_mod) + } +} + +impl<'a> Module for DsoModule<'a> { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { + let name = name.as_ref(); + let func = match unsafe { + self.lib + .get::(if name.as_bytes() == TVM_MAIN { + // If __tvm_main__ is present, it contains the name of the + // actual main function. + match self + .lib + .get::<*const c_char>(TVM_MAIN) + .map(|p| CStr::from_ptr(*p)) + { + Ok(m) => m.to_bytes(), + _ => return None, + } + } else { + name.as_bytes() + }) + } { + Ok(func) => unsafe { func.into_raw() }, + Err(_) => return None, + }; + + self.packed_funcs.borrow_mut().insert( + name.to_string(), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)), + ); + + self.packed_funcs.borrow().get(name).map(|f| *f) + } +} + +impl<'a> Drop for DsoModule<'a> { + fn drop(&mut self) { + self.packed_funcs + .replace(HashMap::new()) + .into_iter() + .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) }) + .for_each(std::mem::drop); + } +} diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs new file mode 100644 index 000000000000..2c7c107f6b30 --- /dev/null +++ b/rust/runtime/src/module/mod.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +mod dso; +mod syslib; + +use tvm_common::{ + ffi::BackendPackedCFunc, + packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, +}; + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +pub use dso::DsoModule; +pub use syslib::SystemLibModule; + +pub trait Module { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; +} + +// @see `WrapPackedFunc` in `llvm_module.cc`. +fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box { + box move |args: &[TVMArgValue]| { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(tvm_common::errors::FuncCallError::get_with_context( + func_name.clone(), + )) + } + } +} diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module/syslib.rs similarity index 62% rename from rust/runtime/src/module.rs rename to rust/runtime/src/module/syslib.rs index 865338f848fa..227b8c727e8f 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module/syslib.rs @@ -21,14 +21,9 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; -use tvm_common::{ - ffi::BackendPackedCFunc, - packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, -}; +use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; -pub trait Module { - fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; -} +use super::Module; pub struct SystemLibModule; @@ -53,30 +48,6 @@ impl Default for SystemLibModule { } } -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func( - func_name: String, - func: BackendPackedCFunc, -) -> Box { - box move |args: &[TVMArgValue]| { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(tvm_common::errors::FuncCallError::get_with_context( - func_name.clone(), - )) - } - } -} - #[no_mangle] pub extern "C" fn TVMBackendRegisterSystemLibSymbol( cname: *const c_char, @@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( name.to_string(), - &*Box::leak(wrap_backend_packed_func(name.to_string(), func)), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), ); return 0; } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 96143848f5e2..eb2f418473ed 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv; #[cfg(target_env = "sgx")] use super::{TVMArgValue, TVMRetValue}; -type FTVMParallelLambda = +pub(crate) type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; /// Holds a parallel job request made by a TVM library function. diff --git a/rust/runtime/tests/test_tvm_dso/Cargo.toml b/rust/runtime/tests/test_tvm_dso/Cargo.toml new file mode 100644 index 000000000000..afe7f26e1220 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-tvm-dso" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_tvm_dso/build.rs b/rust/runtime/tests/test_tvm_dso/build.rs new file mode 100644 index 000000000000..f1d9822b01a5 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{env, path::Path, process::Command}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/test.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); +} diff --git a/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py new file mode 100755 index 000000000000..63b43a5f9bef --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm.contrib import cc + +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + obj_file = osp.join(sys.argv[1], 'test.o') + tvm.build(s, [A, B, C], 'llvm').save(obj_file) + cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file]) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_tvm_dso/src/main.rs b/rust/runtime/tests/test_tvm_dso/src/main.rs new file mode 100644 index 000000000000..953676cea5bb --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/src/main.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, DsoModule, Module}; + +fn main() { + tvm_runtime::TVMGetLastError(); + let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap(); + let add = module + .get_function("__tvm_main__") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 1728fece5965..cdf777c86c0e 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -48,6 +48,10 @@ cd tests/test_tvm_basic cargo run cd - +cd tests/test_tvm_dso +cargo run +cd - + # run NNVM graph test cd tests/test_nnvm cargo run From cbd62f1ce5a18dca7e70eff8675715a766bf0600 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 29 May 2019 07:20:58 +0900 Subject: [PATCH 056/176] [TOPI] Fix resize nearest with fractional scaling (#3244) --- nnvm/tests/python/compiler/test_top_level2.py | 2 +- .../python/frontend/coreml/test_forward.py | 2 +- .../python/frontend/onnx/test_forward.py | 2 +- tests/python/frontend/coreml/test_forward.py | 2 +- tests/python/frontend/onnx/test_forward.py | 2 +- tests/python/relay/test_op_level2.py | 2 +- tests/python/relay/test_op_level5.py | 2 +- topi/include/topi/image/resize.h | 21 +++++----------- topi/python/topi/nn/upsampling.py | 1 - topi/python/topi/testing/upsampling_python.py | 16 ++++++++++--- topi/tests/python/test_topi_resize.py | 24 ++++++++++++------- topi/tests/python/test_topi_upsampling.py | 2 +- 12 files changed, 43 insertions(+), 35 deletions(-) diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index b25feb74793f..3c5651578b64 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -305,7 +305,7 @@ def test_upsampling_nearest_neighbor(): data = tvm.nd.array(a_np) m.run(x=data) out = m.get_output(0, tvm.nd.empty(oshape, dtype)) - b_np = topi.testing.upsampling_python(a_np, scale, "NCHW") + b_np = topi.testing.upsampling_python(a_np, (scale, scale), "NCHW") tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) def test_upsampling_bilinear(): diff --git a/nnvm/tests/python/frontend/coreml/test_forward.py b/nnvm/tests/python/frontend/coreml/test_forward.py index 679afe4e86bc..7a9f294f4359 100644 --- a/nnvm/tests/python/frontend/coreml/test_forward.py +++ b/nnvm/tests/python/frontend/coreml/test_forward.py @@ -195,7 +195,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): a_np = np.full(input_dim, 1, dtype=dtype) if mode == 'NN': - b_np = topi.testing.upsampling_python(a_np, scale) + b_np = topi.testing.upsampling_python(a_np, (scale, scale)) else: new_h = input_dim[2] * scale new_w = input_dim[3] * scale diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 941a275a8045..3365b0f25fb1 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -405,7 +405,7 @@ def _test_upsample_nearest(): y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index da78e960091d..0b6f91bed54f 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -179,7 +179,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): a_np = np.full(input_dim, 1, dtype=dtype) if mode == 'NN': - b_np = topi.testing.upsampling_python(a_np, scale) + b_np = topi.testing.upsampling_python(a_np, (scale, scale)) else: new_h = input_dim[2] * scale new_w = input_dim[3] * scale diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 77f045aa06cc..095f1feb246a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -417,7 +417,7 @@ def _test_upsample_nearest(): y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index a5350450b0a5..c8f5b1d27a2a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -485,7 +485,7 @@ def get_shape(): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) if method == "NEAREST_NEIGHBOR": - ref = topi.testing.upsampling_python(data, scale, layout) + ref = topi.testing.upsampling_python(data, (scale, scale), layout) else: ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout) for target, ctx in ctx_list(): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e6d99c765c87..21b227f6b3b5 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -48,7 +48,7 @@ def verify_resize(dshape, scale, method, layout): if method == "BILINEAR": ref_res = topi.testing.bilinear_resize_python(x_data, size, layout) else: - ref_res = topi.testing.upsampling_python(x_data, scale, layout) + ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize(x, size, layout, method, False) assert "size=" in z.astext() diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index fb577a8f06ef..287ff9406618 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -101,15 +101,12 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, out_shape.push_back(shape[1]); out_shape.push_back(input->shape[3]); - Expr h_ratio = shape[0] / input->shape[1]; - Expr w_ratio = shape[1] / input->shape[2]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); - idx.push_back(indices[1] / h_ratio); - idx.push_back(indices[2] / w_ratio); + idx.push_back(indices[1] * input->shape[1] / shape[0]); + idx.push_back(indices[2] * input->shape[2] / shape[1]); idx.push_back(indices[3]); return input(idx); @@ -138,16 +135,13 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, out_shape.push_back(shape[0]); out_shape.push_back(shape[1]); - Expr h_ratio = shape[0] / input->shape[2]; - Expr w_ratio = shape[1] / input->shape[3]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); - idx.push_back(indices[2] / h_ratio); - idx.push_back(indices[3] / w_ratio); + idx.push_back(indices[2] * input->shape[2] / shape[0]); + idx.push_back(indices[3] * input->shape[3] / shape[1]); return input(idx); }, name, tag); @@ -176,16 +170,13 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, out_shape.push_back(shape[1]); out_shape.push_back(input->shape[4]); - Expr h_ratio = shape[0] / input->shape[2]; - Expr w_ratio = shape[1] / input->shape[3]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); - idx.push_back(indices[2] / h_ratio); - idx.push_back(indices[3] / w_ratio); + idx.push_back(indices[2] * input->shape[2] / shape[0]); + idx.push_back(indices[3] * input->shape[3] / shape[1]); idx.push_back(indices[4]); return input(idx); diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 14c7c05e00fa..7926df205a56 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -53,5 +53,4 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) else: raise ValueError("not support this layout {} yet".format(layout)) - return topi.cpp.nn.upsampling(data, out_shape, layout, method) diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py index 8ee964010c82..167fdfc7f227 100644 --- a/topi/python/topi/testing/upsampling_python.py +++ b/topi/python/topi/testing/upsampling_python.py @@ -16,25 +16,35 @@ # under the License. # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals """Upsampling in python""" +import math import numpy as np def upsample_nearest(arr, scale): """ Populate the array by scale factor""" - return arr.repeat(scale, axis=0).repeat(scale, axis=1) + h, w = arr.shape + out_h = math.floor(h * scale[0]) + out_w = math.floor(w * scale[1]) + out = np.empty((out_h, out_w)) + for y in range(out_h): + for x in range(out_w): + in_y = math.floor(y / scale[0]) + in_x = math.floor(x / scale[1]) + out[y, x] = arr[in_y, in_x] + return out def upsampling_python(data, scale, layout='NCHW'): """ Python version of scaling using nearest neighbour """ ishape = data.shape if layout == 'NCHW': - oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale) + oshape = (ishape[0], ishape[1], math.floor(ishape[2]*scale[0]), math.floor(ishape[3]*scale[1])) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[1]): output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) return output_np if layout == 'NHWC': - oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3]) + oshape = (ishape[0], math.floor(ishape[1]*scale[0]), math.floor(ishape[1]*scale[1]), ishape[3]) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[3]): diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 26a5e3549de7..82778863af55 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -23,8 +23,7 @@ from common import get_all_backend -def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False): - +def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') dtype = A.dtype @@ -39,9 +38,14 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners) + B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) - b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + if method == "BILINEAR": + b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + else: + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) def check_device(device): ctx = tvm.context(device, 0) @@ -61,15 +65,19 @@ def check_device(device): for device in get_all_backend(): check_device(device) + def test_resize(): # Scale NCHW - verify_bilinear_scale(4, 16, 32, 32, 50, 50, 'NCHW') + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') # Scale NCHW + Align Corners - verify_bilinear_scale(6, 32, 64, 64, 20, 20, 'NCHW', True) + verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True) # Scale NHWC - verify_bilinear_scale(4, 16, 32, 32, 50, 50, "NHWC") + verify_resize(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_bilinear_scale(6, 32, 64, 64, 20, 20, "NHWC", True) + verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) + # Nearest + Fractional + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR") + verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR") if __name__ == "__main__": test_resize() diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 0838f02303f6..ddfb002b6f91 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -46,7 +46,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH out_size = (in_height*scale, in_width*scale) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout) else: - b_np = topi.testing.upsampling_python(a_np, scale, layout) + b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout) def check_device(device): ctx = tvm.context(device, 0) From cb0fe1dea97f86581ea7dadbf8c357a5e5cbec89 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 28 May 2019 18:12:17 -0700 Subject: [PATCH 057/176] [C++] Cleanup transform API nits (#3253) --- include/tvm/relay/transform.h | 107 ++++++++++++++++++++++----------- src/relay/pass/pass_manager.cc | 65 ++++++++------------ 2 files changed, 96 insertions(+), 76 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4d6921a6b860..1c1b60813b78 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -76,8 +76,8 @@ namespace transform { class PassContext; /*! - * \brief PassContextNode contains the information that a pass can rely on, such as - * analysis results. + * \brief PassContextNode contains the information that a pass can rely on, + * such as analysis results. */ class PassContextNode : public RelayNode { public: @@ -110,32 +110,51 @@ class PassContextNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); }; +/*! + * \brief PassContext that is used to configure the pass behavior. + * + * \code + * + * auto new_ctx = PassContext::Create(); + * ctx->opt_level = 2; + * ctx->fallback_device = kDLCPU; + * With scope(ctx); + * // pass context in effect. + * + * \endcode + */ class PassContext : public NodeRef { public: PassContext() {} - explicit PassContext(tvm::NodePtr n) : NodeRef(n) {} - - /* - * \brief Constructor of a `PassContext` object. - * - * \param opt_level The optimization level that will be applied. - * \param fallback_device The fallback device used for heterogeneous - * execution. - * \param required_pass The passes that are required for a context to execute - * other passes. - * \param required_pass The passes that will be disabled during the - * optimization under a context. + explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {} + /*! + * \brief const accessor. + * \return const access pointer. + */ + const PassContextNode* operator->() const { + CHECK(node_.get() != nullptr); + return static_cast(node_.get()); + } + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + PassContextNode* operator->() { + CHECK(node_.get() != nullptr); + return static_cast(node_.get()); + } + /*! + * \brief Construct a PassContext containing the default configurations. + * \return The new PassContext. + */ + TVM_DLL static PassContext Create(); + /*! + * \brief Get the default pass context in the current scope. + * \return The pass context. */ - TVM_DLL PassContext(int opt_level, - int fallback_device, - tvm::Array required_pass, - tvm::Array disabled_pass); - - // Get the currently used pass context. TVM_DLL static PassContext Current(); - const PassContextNode* operator->() const; - + // accessor. using ContainerType = PassContextNode; class Internal; @@ -204,25 +223,23 @@ class PassNode : public RelayNode { virtual PassInfo Info() const = 0; /*! - * \brief Execute the optimization pass using a functor. This functor - * internally uses a current pass context. + * \brief Transform mod using the default PassContext in the current scope. * * \param mod The module that an optimization pass runs on. * - * \return The updated module. + * \return The transformed module. */ Module operator()(const Module& mod) const { return this->operator()(mod, PassContext::Current()); } /*! - * \brief Execute the optimization pass using a functor under a given pass context. + * \brief Transform mod using a functor under a given pass context. * * \param mod The module that an optimization pass runs on. - * \param pass_ctx The pass context that will be used to help the execution of - * optimizations. + * \param pass_ctx The pass context that can provide information for the optimization. * - * \return The updated module. + * \return The transformed module. */ virtual Module operator()(const Module& mod, const PassContext& pass_ctx) const = 0; @@ -235,14 +252,34 @@ class PassNode : public RelayNode { class Pass : public NodeRef { public: - Pass() = default; - explicit Pass(NodePtr p) : NodeRef(p) {} - - PassNode* operator->() const { - return static_cast(this->node_.get()); + /*! + * \brief Transform mod using the default PassContext in the current scope. + * + * \param mod The module that an optimization pass runs on. + * + * \return The transformed module. + */ + Module operator()(const Module& mod) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod); + } + /*! + * \brief Transform mod using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ + Module operator()(const Module& mod, + const PassContext& pass_ctx) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod, pass_ctx); } - using ContainerType = PassNode; + TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode); }; class SequentialNode; diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index ea4c976b7db5..a9c671aa163a 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -74,21 +74,6 @@ class OptPassLevel { } }; -PassContext::PassContext(int opt_level, int fallback_device, - tvm::Array required_pass, - tvm::Array disabled_pass) { - auto ctx = make_node(); - ctx->opt_level = opt_level; - ctx->fallback_device = fallback_device; - ctx->required_pass = std::move(required_pass); - ctx->disabled_pass = std::move(disabled_pass); - node_ = std::move(ctx); -} - -const PassContextNode* PassContext::operator->() const { - return static_cast(node_.get()); -} - struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ PassContext default_context; @@ -129,6 +114,10 @@ PassContext PassContext::Current() { } } +PassContext PassContext::Create() { + return PassContext(make_node()); +} + class ModulePass; /*! @@ -291,7 +280,7 @@ class SequentialNode : public PassNode { * * \return true if the pass is enabled. Otherwise, false. */ - bool pass_enabled(const std::string& pass_name) const; + bool PassEnabled(const std::string& pass_name) const; /*! * \brief Resolve the pass dependency. It globs all required passes by @@ -353,9 +342,8 @@ ModulePass ModulePassNode::make( Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); - LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name - << " with opt level: " << pass_info.operator->()->opt_level << "\n"; - + DLOG(INFO) << "Executing module pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level << "\n"; CHECK(mod.defined()); auto updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); @@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make( Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); - LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name - << " with opt level: " << pass_info.operator->()->opt_level << "\n"; CHECK(mod.defined()); Module new_mod = ModuleNode::make({}, mod->type_definitions); - + DLOG(INFO) << "Executing module pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level << "\n"; // Execute the pass function and return a new module. for (const auto& it : mod->functions) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); @@ -448,12 +435,11 @@ std::unordered_set SequentialNode::RequiredPasses( return ret; } -bool SequentialNode::pass_enabled(const std::string& pass_name) const { +bool SequentialNode::PassEnabled(const std::string& pass_name) const { PassContext ctx = PassContext::Current(); - const PassContextNode* ctx_node = ctx.operator->(); - auto required = RequiredPasses(ctx_node->required_pass); - auto disabled = DisabledPasses(ctx_node->required_pass); + auto required = RequiredPasses(ctx->required_pass); + auto disabled = DisabledPasses(ctx->required_pass); if (disabled.count(pass_name)) { return false; @@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const { if (required.count(pass_name)) { return true; } - return ctx_node->opt_level >= opt_pass_level[pass_name]; + return ctx->opt_level >= opt_pass_level[pass_name]; } // TODO(zhiics): we currenlty only sequentially execute each pass in @@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const { // ordering problem needed to be handled in the future. Module SequentialNode::operator()(const Module& module, const PassContext& pass_ctx) const { - const auto* ctx_node = pass_ctx.operator->(); - int opt_level = ctx_node->opt_level; - auto disabled = DisabledPasses(ctx_node->disabled_pass); + int opt_level = pass_ctx->opt_level; + auto disabled = DisabledPasses(pass_ctx->disabled_pass); Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; PassInfo info = pass->Info(); - const auto& pass_name = info.operator->()->name; - const auto& pass_opt_level = info.operator->()->opt_level; + const auto& pass_name = info->name; + const auto& pass_opt_level = info->opt_level; // Skip the pass if its optimization level is higher that the one of in the // pass context or if this pass is disabled. if (pass_opt_level > opt_level || disabled.count(pass_name)) { @@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass") TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { - Pass pass = args[0]; - Module mod = args[1]; - CHECK(pass.defined()) - << "Running an undefined pass is not allowed." - << "\n"; - - const auto* pn = pass.operator->(); - *ret = (*pn)(mod); + *ret = args[0].operator Pass()(args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._transform.PassContext") .set_body([](TVMArgs args, TVMRetValue* ret) { + auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; tvm::Array required = args[2]; tvm::Array disabled = args[3]; - *ret = PassContext(opt_level, fallback_device, required, disabled); + pctx->opt_level = opt_level; + pctx->fallback_device = fallback_device; + pctx->required_pass = std::move(required); + pctx->disabled_pass = std::move(disabled); + *ret = pctx; }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) From 905367718b0dfde9923439c1052765fdedfeae20 Mon Sep 17 00:00:00 2001 From: Hua Date: Wed, 29 May 2019 10:32:47 -0700 Subject: [PATCH 058/176] [BugFix][VTA] Fix vta_conv2d crash issue after change vta_config.json configuration. (#3213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: Once change LOG_BLOCK_IN or LOG_BLOCK_OUT into > 4 value, when run vta “Simple Matrix Multiply” or load vta, vta would crash at vta_conv2d.py. Analysis: This issue caused by resnet18 logic of vta_conv2d.py which have in_filter minmum size that is 16. > 4 value would cause such in_filter check failed then make xfer_size be empty and find_schedules function return a empty list finally cause crash. Solution: add the empty list check. --- vta/python/vta/top/vta_conv2d.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 9d0d3dfe3d72..d685379f136c 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -165,7 +165,7 @@ def _get_data_movement_byte(schedule, layer): fil_sched.append(schedule) xfer_size.append(_get_data_movement_byte(schedule, layer)) - if best_only: + if best_only and xfer_size: return [fil_sched[xfer_size.index(min(xfer_size))]] return fil_sched @@ -515,5 +515,10 @@ def __str__(self): } for idx in RESNET: - scheds = find_schedules(RESNET[idx], vt_only=True, best_only=True)[0] - _WL2PLAN[RESNET[idx]] = scheds + f_schedules = find_schedules(RESNET[idx], vt_only=True, best_only=True) + if f_schedules: + scheds = f_schedules[0] + _WL2PLAN[RESNET[idx]] = scheds + else: + logging.warning("No valid schedule was found for the workload on current vta configuration") + break From 4cd7589a1e19197cf516c249a32f635fe91c5f68 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 29 May 2019 16:36:05 -0700 Subject: [PATCH 059/176] [AutoTVM]Core functionality for Graph tuner (#2184) * Add graph tuning * Add tests * Fix tests * Fix pylint * Small fix for docstring * Minor fix * Support fetching workload from relay expr * Simplify benchmark layout transformation * Add relay support * Fix infer layout func name * Refactor internal data representation * Fix issues * Add PBQP solver * Fix layout transform check * Add PBQPTuner test * Fix lint * Update tutorial * Fix tutorial * Fix lint * Add relay test * Remove nnvm since nnvm graph can be converted to relay function * Modify benchmark layout wrt new layout_transform api * Fix lint * Update docstring for DP tuner * Refactor traverse graph * Support graph tuning for multiple target operators * Fix fetching workloads * Add x86 depthwise_conv2d infer_layout * Fix x86 depthwise_conv2d autotvm * Fix PBQP tuner * Fix DP tuner * Generate dummy layout transform record * Update tutorial * Modify layout records name * Add ASF header * Add ASF header for testing files * Fix test * Fix topi fetching * Some refactors * Fix lint * Fix tutorial * Rename test files * Fix doc typo * Add test case note link --- python/tvm/autotvm/graph_tuner/__init__.py | 25 + python/tvm/autotvm/graph_tuner/_base.py | 27 + .../autotvm/graph_tuner/base_graph_tuner.py | 522 ++++++++++++++++++ .../graph_tuner/dynamic_programming_stage.py | 358 ++++++++++++ .../graph_tuner/dynamic_programming_tuner.py | 189 +++++++ python/tvm/autotvm/graph_tuner/pbqp_tuner.py | 288 ++++++++++ .../tvm/autotvm/graph_tuner/utils/__init__.py | 26 + .../graph_tuner/utils/traverse_graph.py | 312 +++++++++++ python/tvm/autotvm/graph_tuner/utils/utils.py | 110 ++++ python/tvm/autotvm/task/__init__.py | 3 +- python/tvm/autotvm/task/topi_integration.py | 19 +- .../python/unittest/test_graph_tuner_core.py | 254 +++++++++ .../python/unittest/test_graph_tuner_utils.py | 149 +++++ topi/python/topi/nn/conv2d.py | 20 + topi/python/topi/nn/depthwise_conv2d.py | 19 + topi/python/topi/x86/conv2d.py | 17 +- topi/python/topi/x86/depthwise_conv2d.py | 20 +- tutorials/autotvm/tune_relay_x86.py | 17 +- 18 files changed, 2364 insertions(+), 11 deletions(-) create mode 100644 python/tvm/autotvm/graph_tuner/__init__.py create mode 100644 python/tvm/autotvm/graph_tuner/_base.py create mode 100644 python/tvm/autotvm/graph_tuner/base_graph_tuner.py create mode 100644 python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py create mode 100644 python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py create mode 100644 python/tvm/autotvm/graph_tuner/pbqp_tuner.py create mode 100644 python/tvm/autotvm/graph_tuner/utils/__init__.py create mode 100644 python/tvm/autotvm/graph_tuner/utils/traverse_graph.py create mode 100644 python/tvm/autotvm/graph_tuner/utils/utils.py create mode 100644 tests/python/unittest/test_graph_tuner_core.py create mode 100644 tests/python/unittest/test_graph_tuner_utils.py diff --git a/python/tvm/autotvm/graph_tuner/__init__.py b/python/tvm/autotvm/graph_tuner/__init__.py new file mode 100644 index 000000000000..d590db0e7c48 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Autotvm graph tuner API.""" +from __future__ import absolute_import as _abs + +from . import _base +from . import base_graph_tuner + +from .base_graph_tuner import BaseGraphTuner +from .dynamic_programming_tuner import DPTuner +from .pbqp_tuner import PBQPTuner diff --git a/python/tvm/autotvm/graph_tuner/_base.py b/python/tvm/autotvm/graph_tuner/_base.py new file mode 100644 index 000000000000..83b9e06ba564 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/_base.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Helper functions and global data""" + + +RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape", + "multibox_prior", "multibox_transform_loc", "where", + "non_max_suppression", "strided_slice"] + +# We set a large time to represent an invalid layout-transformation. +# This number is set to be 10e9 seconds to align with autotvm. +INVALID_LAYOUT_TIME = 10e9 diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py new file mode 100644 index 000000000000..0fbfc27310cb --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -0,0 +1,522 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-instance-attributes,too-many-branches,too-many-nested-blocks,invalid-name,unused-argument,unused-variable,no-member,no-value-for-parameter +"""Base class for graph tuner.""" +import logging +from abc import abstractmethod + +import numpy as np +import topi + +import tvm +from tvm import autotvm, relay +from tvm.autotvm.task import get_config +from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args +from tvm.autotvm.record import encode, load_from_file +from tvm.autotvm.measure import MeasureResult, MeasureInput + +from ... import target as _target +from .utils import is_input_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \ + bind_inputs, expr2graph +from ._base import INVALID_LAYOUT_TIME + + +# Setup topi_op_name -> layout function +# NOTE: To add more ops, change the following dictionary. +OP2LAYOUT = { + "topi_nn_conv2d": topi.nn.conv2d_infer_layout, + "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout, +} + + +@autotvm.template +def layout_transform(*args): + """Autotvm layout transform template.""" + args = deserialize_args(args) + cfg = get_config() + cfg.add_flop(-1) + data = args[0] + out = topi.layout_transform(*args) + sch = topi.generic.schedule_injective([out]) + return sch, [data, out] + + +class BaseGraphTuner(object): + """Class to search schedules considering both kernel execution time and + layout transformation time. + + Before creating a Graph Executor instance, schedule candidates for all kernels in + graph should be provided through tensor-level tuning. + """ + def __init__(self, graph, input_shapes, records, target_ops, + target, max_sch_num=20, dtype="float32", verbose=True, + log_file="graph_tuner.log", log_level=logging.DEBUG, + name="graph_tuner"): + """Create a GlobalTuner instance. Local schedule searching for all nodes with + target_op in the input graph and layout transformation benchmark need to be + executed before initialization. + + graph : tvm.relay.Expr.Function + Input graph + + input_shapes : dict of str to tuple. + Input shapes of graph + + records : str or iterator of (MeasureInput, MeasureResult) + Collection of kernel level tuning records. + If it is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + + target_ops : List of str + Target tuning operators. + + target : str or tvm.target + Compilation target. + + max_sch_num : int, optional + Maximum number of schedule candidates for each workload. + + dtype : str, optional + Data type. + + log_file : str, optional + graph tuner log file name + + name : str, optional + Name of global tuner. + """ + self._node_list = [] + self._layout_transform_perf_records = {} + self._layout_transform_interlayer_cost = {} + self._input_shapes = input_shapes + self._target_ops = [op.__name__ for op in target_ops] + + self._name = name + self._max_sch_num = max_sch_num + self._optimal_sch_dict = {} + self._records = records + self._dtype = dtype + if isinstance(target, str): + target = _target.create(target) + self._target = target + self._optimal_record_dict = {} + + # Set up logger + self._verbose = verbose + self._logger = logging.getLogger(name + "_logger") + need_file_handler = need_console_handler = True + for handler in self._logger.handlers: + if handler.__class__.__name__ == 'FileHandler': + need_file_handler = False + if handler.__class__.__name__ == 'StreamHandler': + need_console_handler = False + self._log_level = log_level + self._log_file = log_file + self._formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') + self._logger.setLevel(log_level) + if need_file_handler: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(self._formatter) + self._logger.addHandler(file_handler) + if self._verbose and need_console_handler: + console_handler = logging.StreamHandler() + console_handler.setFormatter(self._formatter) + self._logger.addHandler(console_handler) + self._logger.setLevel(log_level) + self._logger.propagate = False + + # Generate workload and schedule dictionaries. + if isinstance(graph, relay.expr.Function): + node_dict = {} + graph = bind_inputs(graph, input_shapes, dtype) + expr2graph(graph, self._target_ops, node_dict, self._node_list) + else: + raise RuntimeError("Unsupported graph type: %s" % str(type(graph))) + + self._graph = graph + self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys()) + self._out_nodes_dict = get_out_nodes(self._in_nodes_dict) + self._fetch_cfg() + + # Setup infer_layout for elemwise-like nodes + # Note: graph tuner currently only supports tuning of single input and single output + # op as target op, such as conv2d, dense and conv2d_transpose. In this case, we can + # reuse infer_layout function from target ops for elemwise-like nodes. The behavior + # is to modify the first tensor shape of input workload to the output shape of + # elemwise-like node, and use infer_layout function from input op to generate layouts. + input_names = self._input_shapes.keys() + for idx in sorted(self._in_nodes_dict.keys()): + if has_multiple_inputs(self._node_list, idx, input_names): + node_entry = self._node_list[idx] + node_entry["topi_op"] = [] + node_entry["workloads"] = [] + for input_idx in self._in_nodes_dict[idx]: + input_node = self._node_list[input_idx] + if not is_input_node(input_node, input_names): + input_topi_op = input_node["topi_op"][0] + node_entry["topi_op"].append(input_topi_op) + # Only replace the first input tensor + input_workload = input_node["workloads"][0] + first_tensor = input_workload[1] + dtype = first_tensor[-1] + new_shape = tuple([val.value for val in node_entry["types"][0].shape]) + actual_workload = (input_workload[0],) + \ + ((new_shape + (dtype,)),) + input_workload[2:] + node_entry["workloads"].append(actual_workload) + if "record_candidates" not in node_entry: + node_entry["record_candidates"] = input_node["record_candidates"] + else: + node_entry["topi_op"].append(None) + node_entry["workloads"].append(None) + + + def _fetch_cfg(self): + """Read and pre-process input schedules.""" + if isinstance(self._records, str): + records = load_from_file(self._records) + else: + records = self._records + cfg_dict = {} + for record in records: + in_measure, _ = record + workload = in_measure.task.workload + if workload not in cfg_dict: + cfg_dict[workload] = [] + cfg_dict[workload].append(record) + + cache_dict = {} + for key in self._in_nodes_dict: + node_entry = self._node_list[key] + if node_entry["op"] not in self._target_ops: + continue + workload = node_entry["workloads"][0] + if workload in cache_dict: + node_entry["record_candidates"] = cache_dict[workload] + continue + record_candidates = [] + infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + layout_tracking_dict = {} + for record in cfg_dict[workload]: + in_measure, out_measure = record + workload = in_measure.task.workload + cfg = in_measure.config + # For multiple cfgs which produces the same in/out layouts, + # only the most efficient one is preserved. + with self._target: + layouts = infer_layout_func(workload, cfg) + if layouts in layout_tracking_dict: + cost = out_measure.costs[0] + current_best_cost = layout_tracking_dict[layouts][1].costs[0] + if cost < current_best_cost: + layout_tracking_dict[layouts] = record + else: + layout_tracking_dict[layouts] = record + sorted_records = sorted(layout_tracking_dict.values(), + key=lambda item: item[1].costs[0]) + for i in range(min(self._max_sch_num, len(sorted_records))): + record_candidates.append(sorted_records[i]) + node_entry["record_candidates"] = record_candidates + cache_dict[workload] = record_candidates + + def _iterate_layout_transform(self, callback): + """Iterate all possible layout transformations and execute callback for each + iteration. callback function accepts 6 arguments: from_node_idx, to_node_idx, + from_sch_idx, to_sch_idx, args which represent the argument list of layout + transformation and is_valid showing whether this is a valid layout transformation. + """ + input_names = self._input_shapes.keys() + for key, val in self._in_nodes_dict.items(): + node_entry = self._node_list[key] + target_input_idx = -1 + target_input_pos = -1 + if has_multiple_inputs(self._node_list, key, input_names): + for i, item in enumerate(val): + if not is_input_node(self._node_list[item], input_names): + target_input_idx = item + target_input_pos = i + break + + for i, item in enumerate(val): + i_idx = item + in_node_entry = self._node_list[i_idx] + if is_input_node(in_node_entry, input_names): + continue + + if node_entry["op"] in self._target_ops: + o_idx = key + o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + o_wkl = node_entry["workloads"][0] + i_topi_op = in_node_entry["topi_op"][0] + i_wkl = in_node_entry["workloads"][0] + pivot = 0 + while not i_wkl: + pivot += 1 + i_topi_op = in_node_entry["topi_op"][pivot] + i_wkl = in_node_entry["workloads"][pivot] + i_infer_layout_func = OP2LAYOUT[i_topi_op] + else: + o_idx = target_input_idx + if i <= target_input_pos: + continue + o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + o_wkl = node_entry["workloads"][target_input_pos] + i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]] + i_wkl = node_entry["workloads"][i] + + + for m, i_record in enumerate(in_node_entry["record_candidates"]): + for n, o_record in enumerate(node_entry["record_candidates"]): + i_cfg, o_cfg = i_record[0].config, o_record[0].config + with self._target: + i_input_info, i_output_info = i_infer_layout_func(i_wkl, i_cfg) + o_input_info, o_output_info = o_infer_layout_func(o_wkl, o_cfg) + if len(i_input_info) > 1 or len(i_output_info) > 1 or \ + len(o_input_info) > 1 or len(o_output_info) > 1: + raise RuntimeError("Graph tuner only supports target operator " + "with single input and single output. " + "Please check target_ops argument.") + + in_shape, in_layout = i_output_info[0] + if node_entry["op"] in self._target_ops: + _, out_layout = o_input_info[0] + else: + _, out_layout = o_output_info[0] + data_placeholder = tvm.placeholder(in_shape, name="data", + dtype=self._dtype) + args = [data_placeholder, in_layout, out_layout] + callback(i_idx, o_idx, m, n, args) + + + def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx, + to_sch_idx, args): + """Create dictionary containing matrix format of layout transformation + between nodes.""" + sargs = serialize_args(args) + in_layout, out_layout = args[1], args[2] + ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs) + idx_pair_key = (from_node_idx, to_node_idx) + + if in_layout == out_layout: + layout_transform_time = 0 + else: + layout_transform_time = \ + self._layout_transform_perf_records[ltf_workload][1].costs[0] + + if idx_pair_key not in self._layout_transform_interlayer_cost: + self._layout_transform_interlayer_cost[idx_pair_key] = [] + if len(self._layout_transform_interlayer_cost[idx_pair_key]) <= from_sch_idx: + self._layout_transform_interlayer_cost[idx_pair_key].append([]) + self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx]\ + .append(layout_transform_time) + + def benchmark_layout_transform(self, min_exec_num=100, timeout=10, + use_rpc=False, device_key=None, host="localhost", + port=9190, n_parallel=1, build_func='default', + layout_records=None, target_host=None, infer_layout=False): + """Benchmark all possible layout transformation in the graph, + given a set of schedule candidates for each workload of target operator. + + Parameters + ---------- + min_exec_num : int, optional + Minimum number of execution. Final execution time is the average of + all execution time. + + timeout : int, optional + Time out for each execution. + + use_rpc : boolean, optional + Whether to use rpc mode for benchmarking. + + device_key : str, optional + Remote device key which can be queried by + python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190 + + host : str, optional + IP address used to create RPC tracker on host machine. + + port : int, optional + Port number used to create RPC tracker on host machine. + + n_parallel: int, optional + The number of measurement task that can run in parallel. + Set this according to the number of cpu cores (for compilation) and + the number of devices you have (for measuring generate code). + + build_func: str or callable, optional + 'default': call default builder. This works for normal target (llvm, cuda) + + 'ndk': use Android NDK to create shared library. Use this for android target. + + callable: customized build function for other backends (e.g. VTA). + See autotvm/measure/measure_methods.py::default_build_func for example. + + layout_records : str or iterator of (MeasureInput, MeasureResult). optional + Collection of layout_transform benchmarking records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + + If this argument is set, graph tuner will first check whether layout_transform + workload already exists in records and skip benchmarking if possible. + + target_host : str, optional + str or :any:`tvm.target.Target` optional + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + + infer_layout : bool, optional + Whether to infer layout transformation time if it doesn't exist in records, instead + of benchmarking on target device. + + This might bring performance loss comparing to benchmarking layout transformation. + """ + self._logger.info("Start to benchmark layout transformation...") + if layout_records is None and infer_layout: + raise RuntimeError("Requires some records to infer layout transformation time.") + + if isinstance(layout_records, str): + layout_records = load_from_file(layout_records) + if not layout_records and infer_layout: + raise RuntimeError("Records must be non-empty to infer layout transformation time.") + + if isinstance(layout_records, str): + layout_records = load_from_file(layout_records) + num_flops, total_time = 0, 0 + if layout_records is not None: + for record in layout_records: + ltf_wkl = record[0].task.workload + self._layout_transform_perf_records[ltf_wkl] = record + input_shape = ltf_wkl[1][1] + flops = np.prod(input_shape) + num_flops += flops + total_time += record[1].costs[0] + avg_time = total_time / num_flops if num_flops > 0 else 0 + + args_list = [] + def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx, + to_sch_idx, args): + """Callback function to fetch layout transform args""" + _, in_layout, out_layout = args + if in_layout != out_layout: + args_list.append(args) + + self._iterate_layout_transform(_fetch_args_callback) + + def _log_to_list(record_list): + """Callback to log result to a list.""" + def _callback(_, inputs, results): + """Callback implementation""" + record_list.append((inputs[0], results[0])) + return _callback + + builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func) + runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout) + if use_rpc: + if device_key is None: + raise RuntimeError("device_key need to be set to use rpc tracker mode.") + runner = autotvm.measure.RPCRunner(device_key, host, port, n_parallel=n_parallel, + number=min_exec_num, repeat=1, + timeout=timeout) + measure_option = autotvm.measure_option(builder=builder, runner=runner) + for args in args_list: + args = serialize_args(args) + ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args) + if ltf_workload in self._layout_transform_perf_records: + continue + + if infer_layout: + input_shape = ltf_workload[1][1] + flops = 1 + for i in input_shape: + flops *= i + inferred_time = flops * avg_time + record_input = MeasureInput(target=self._target, task=None, config=None) + record_output = MeasureResult(costs=(inferred_time,), error_no=0, + all_cost=-1, timestamp=-1) + self._layout_transform_perf_records[ltf_workload] = (record_input, record_output) + continue + + records = [] + task = autotvm.task.create(layout_transform, args=args, target=self._target, + target_host=target_host) + task.workload = ltf_workload + tuner = autotvm.tuner.GridSearchTuner(task) + tuner.tune(n_trial=1, measure_option=measure_option, + callbacks=[_log_to_list(records)]) + if not isinstance(records[0][1].costs[0], float): + records[0] = (records[0][0], records[0][1]._replace(costs=(INVALID_LAYOUT_TIME,))) + self._layout_transform_perf_records[ltf_workload] = records[0] + + self._iterate_layout_transform(self._create_matrix_callback) + self._logger.info("Benchmarking layout transformation successful.") + + @property + def layout_transform_perf_records(self): + """Get layout transformation dictionary for input graph. + + Returns + ------- + layout_transform_perf_records : dict of tuple to (MeasureInput, MeasureResult) + Layout transformation dictionary for input graph. + """ + return self._layout_transform_perf_records + + + def get_optimal_records(self): + """Convert optimal record dictionary to a list of records + with ascending order of node index in graph. + + Returns + ------- + sch_list : list of tuple + List of records with ascending order of node index in graph. + """ + ordered_index_list = sorted(self._optimal_record_dict.keys()) + ret = [] + for index in ordered_index_list: + node_entry = self._node_list[index] + if node_entry["op"] not in self._target_ops: + continue + ret.append(node_entry["record_candidates"][self._optimal_record_dict[index]]) + return ret + + def write_opt_sch2record_file(self, record_file="graph_opt_schedule.log"): + """Write graph level optimal schedules into file. + + Parameters + ---------- + record_file : str, optional + Output schedule file. + """ + with open(record_file, "a") as out_file: + records = self.get_optimal_records() + for record in records: + out_file.write(encode(record[0], record[1]) + "\n") + msg = "Writing optimal schedules to %s successfully." % record_file + self._logger.info(msg) + + @abstractmethod + def run(self, **kwargs): + """Run graph tuning.""" + pass diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py new file mode 100644 index 000000000000..4a512c224a1d --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py @@ -0,0 +1,358 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-many-instance-attributes,too-many-branches,too-many-statements,too-many-arguments,too-many-locals,invalid-name +"""Stage class for dynamic programming tuner""" +import numpy as np + +from .utils import is_input_node + + +class DPStage(object): + """Class to represent node in Markov decision process. A stage has states + to represent different schedules of the current node. Since in this problem + the action is the schedule selected for current node, action can be fully + represented by states. No extra attribute needs for action. + + In most cases, instance of this class should be created through DPTuner. + """ + def __init__(self, idx, input_shapes, node_list, + counted_nodes_set, layout_transform_interlayer_cost, + stage_dict, in_nodes_dict, out_nodes_dict, + dep_dict, target_ops, dtype="float32"): + """Initialize a stage and create all states. + + Parameters + ---------- + idx : int + Index for current node. + + input_shapes : dict of string to tuple of int + Input shapes for current graph. + + node_list : list of dict + List of all nodes for current graph. + + counted_nodes_set : set of int + Global set recording whether the execution time of a node has been counted. + + layout_transform_interlayer_cost : dict of tuple to list + Dictionary maps node index pair to layout transformation time between them. + + stage_dict : dict of int to Stage + Global dictionary for all stages mapping node index to stage. + + in_nodes_dict : dict of int to list of int + Dictionary maps node index to corresponding input node index. + + out_nodes_dict : dict of int to list of int + Dictionary maps node index to corresponding output node index. + + dep_dict : dict of int to set of int + Dictionary maps node index to dependent node index. + + target_ops : list of str + Target operators + + dtype : str, optional + Data type. + """ + self._global_input_shapes = input_shapes + self._global_input_names = input_shapes.keys() + self._global_node_list = node_list + self._global_counted_nodes_set = counted_nodes_set + self._global_layout_transform_interlayer_cost = layout_transform_interlayer_cost + self._global_stage_dict = stage_dict + self._global_in_nodes_dict = in_nodes_dict + self._global_out_nodes_dict = out_nodes_dict + self._global_dep_dict = dep_dict + + self._idx = idx + self._node_entry = self._global_node_list[idx] + self._target_ops = target_ops + self._wkl = self._node_entry["workloads"][0] + self._record_list = self._node_entry["record_candidates"] + self._dep = [] + self._dtype = dtype + self._states = None + self._full_states = None + self._full_states_idx = None + self._create_states() + + def _create_states(self): + """Create states.""" + node = self._global_node_list[self._idx] + if node["op"] in self._target_ops: + self._create_op_states() + else: + self._create_multi_inputs_states() + + def _create_op_states(self): + """State creation routine for nodes with target_op.""" + input_idx = -1 + for index in self._global_in_nodes_dict[self._idx]: + input_idx = index + if not is_input_node(self._global_node_list[input_idx], + self._global_input_names): + break + + if is_input_node(self._global_node_list[input_idx], + self._global_input_names): + self._full_states = np.array([record[1].costs[0] + for record in self._record_list]) + self._states = self._full_states + else: + input_node_entry = self._global_node_list[input_idx] + input_stage = self._global_stage_dict[input_idx] + input_dep = input_stage.dep + input_states = input_stage.states + input_flatten_states = input_states.flatten() + input_record_list = input_node_entry["record_candidates"] + num_schedules = len(self._record_list) + num_input_schedules = len(input_record_list) + num_input_states = input_flatten_states.shape[0] + + full_states_shape = tuple([num_schedules, num_input_schedules] + + [len(self._global_node_list[dep_idx]["record_candidates"]) + for dep_idx in input_dep]) + self._full_states = np.zeros(full_states_shape).flatten().astype("float32") + self._full_states_idx = [self._idx, input_idx] + input_dep + dep_multiplier = 1 + for i in range(2, len(full_states_shape)): + dep_multiplier *= full_states_shape[i] + input_node_time_counted = input_idx in self._global_counted_nodes_set + + for i in range(num_schedules): + current_sch_time = float(self._record_list[i][1].costs[0]) + for j in range(num_input_states): + input_sch_idx = j // dep_multiplier + layout_transform_time = \ + self._global_layout_transform_interlayer_cost \ + [(input_idx, self._idx)][input_sch_idx][i] + + if input_node_time_counted: + total_time = current_sch_time + layout_transform_time + else: + total_time = \ + current_sch_time + layout_transform_time + input_flatten_states[j] + current_state_idx = i * num_input_states + j + self._full_states[current_state_idx] = total_time + + if not input_node_time_counted: + self._global_counted_nodes_set.add(input_idx) + self._full_states = self._full_states.reshape(full_states_shape) + + # If out degree of input node is 1, we can remove the dimension of input node, + # since the states of input node will not be needed any more. Otherwise, input + # node should become a dependency. + if len(self._global_out_nodes_dict[input_idx]) == 1: + self._states = np.amin(self._full_states, axis=1) + self._dep = list(input_dep) + else: + self._states = self._full_states + self._dep = [input_idx,] + input_dep + + # Update global dependency dictionary. + # This is to monitor the dependency states to decide + # when a dependency can be eliminated, so that total + # number of states can be largely reduced. + for dep_idx in self._dep: + self._global_dep_dict[dep_idx].remove(self._idx) + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[dep_idx].add(child) + if len(self._global_out_nodes_dict[self._idx]) > 1: + self._global_dep_dict[self._idx] = set() + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[self._idx].add(child) + + def _create_multi_inputs_states(self): + """State creation routine for multi_input operator + + In tvm, layout transformation for an elemwise-like follow the rule which + all input operators transform their layouts to the leftmost input operator + layout. For example: + elemwise-sum + | | | + | | | + op0 op1 op2 + In this block, the possible layout transformations are: op1 -> op0 and op2 -> op0. + In graph tuning, a 3-D array with shape (k0, k1, k2) can represent the layout + transformations between these three nodes. It is also possible some earlier states + belong to other nodes(We name them as dependency) are required for dynamic programming. + The final states array for this elemwise-sum can be with shape (e0, k0, k1, e1, k2). + To iterate through all states, we first align the shape of op0, op1 and op2 to be + (e0, k0, k1, e1, k2) by broadcasting the original states. We also record the axis of + each input node in the states array, together with the multiplier. For example, + the axis index for op0 is 1, and multiplier is k1 * e1 * k2. If current iterating index + in the flatten array is i, the index of op0 can be computed as: + i % (k0 * k1 * e1 * k2) // (k1 * e1 * k2). + """ + full_input_node_list = list(self._global_in_nodes_dict[self._idx]) + input_index_list = [] + # Remove input and parameter nodes + for input_idx in full_input_node_list: + if not is_input_node(self._global_node_list[input_idx], + self._global_input_names): + input_index_list.append(input_idx) + + # Generate new states + states_list, aligned_node_list = DPStage.align_states(input_index_list, + self._global_stage_dict, + self._global_node_list) + target_node_idx, target_major_axis, target_multiplier, target_states = states_list[0] + aligned_shape = target_states.shape + self._full_states = np.zeros(aligned_shape).astype("float32").flatten() + self._full_states_idx = list(aligned_node_list) + num_states = self._full_states.shape[0] + node_time_counted = [item[0] in self._global_counted_nodes_set for item in states_list] + target_states = target_states.flatten() + src_states_list = [states_list[i][3].flatten() for i in range(1, len(states_list))] + + for i in range(num_states): + target_sch_idx = (i % (target_multiplier * + aligned_shape[target_major_axis])) // target_multiplier + if node_time_counted[0]: + new_state = 0 + else: + new_state = target_states[i] + + for j in range(1, len(states_list)): + src_states = src_states_list[j - 1] + src_node_idx, src_major_axis, src_multiplier, _ = states_list[j] + src_sch_idx = (i % (src_multiplier * + aligned_shape[src_major_axis])) // src_multiplier + layout_transform_time = \ + self._global_layout_transform_interlayer_cost\ + [(src_node_idx, target_node_idx)][src_sch_idx][target_sch_idx] + + if node_time_counted[j]: + new_state += layout_transform_time + else: + new_state += layout_transform_time + src_states[i] + self._full_states[i] = new_state + + for i, node_counted in enumerate(node_time_counted): + if not node_counted: + self._global_counted_nodes_set.add(states_list[i][0]) + self._full_states = self._full_states.reshape(aligned_shape) + + # Remove dependency to reduce states + reduced_states = np.array(self._full_states) + reduced_states_transpose = [states_list[0][1]] + reduced_states_dep_list = [] + self._dep = [] + for i in range(len(reduced_states.shape)): + if i != states_list[0][1]: + reduced_states_transpose.append(i) + reduced_states_dep_list.append(aligned_node_list[i]) + reduced_states = np.transpose(reduced_states, reduced_states_transpose) + shift = 0 + for i, dep in enumerate(reduced_states_dep_list): + if dep not in self._global_dep_dict or len(self._global_dep_dict[dep]) == 1: + self._global_dep_dict.pop(dep, None) + reduced_states = np.amin(reduced_states, axis=i+1-shift) + shift += 1 + else: + self._dep.append(dep) + self._states = reduced_states + + # Update dependency + for dep in self._dep: + self._global_dep_dict[dep].remove(self._idx) + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[dep].add(child) + if len(self._global_out_nodes_dict[self._idx]) > 1: + self._global_dep_dict[self._idx] = set() + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[self._idx].add(child) + + @property + def dep(self): + """Get dependency list.""" + return self._dep + + @property + def states(self): + """Get states.""" + return self._states + + @property + def full_states(self): + """Get complete states.""" + return self._full_states + + @property + def full_states_idx(self): + """Get node index of complete states.""" + return self._full_states_idx + + @staticmethod + def align_states(input_index_list, stage_dict, node_list): + """Align all input node states shapes to be the same and transpose/reshape properly. + + This is used in creating multi_input operator states. + + Parameters + ---------- + input_index_list : list of int + List of input node index. + + stage_dict : dict of int to Stage + Global dictionary of node index to stage. + + node_list : list of dict + List of all nodes for current graph. + + Returns + ------- + states_list : list of tuple + List of aligned states. + + aligned_node_list : list in int + List of node index for aligned states. + """ + aligned_node_list = list(input_index_list) + states_list = [] + for input_idx in input_index_list: + input_node_stage = stage_dict[input_idx] + for dep_idx in input_node_stage.dep: + if dep_idx not in aligned_node_list: + aligned_node_list.append(dep_idx) + aligned_shape = tuple([len(node_list[idx]["record_candidates"]) + for idx in aligned_node_list]) + for input_idx in input_index_list: + input_node_stage = stage_dict[input_idx] + input_node_shape_idx_list = [input_idx] + input_node_stage.dep + transpose_idx_list = [] + reshape_list = [] + major_axis = -1 + for i, idx in enumerate(aligned_node_list): + if input_idx == idx: + major_axis = i + if idx in input_node_shape_idx_list: + transpose_idx_list.append(idx) + reshape_list.append(aligned_shape[i]) + else: + reshape_list.append(1) + transpose_list = [input_node_shape_idx_list.index(idx) for idx in transpose_idx_list] + input_node_states = np.transpose(input_node_stage.states, tuple(transpose_list)) + input_node_states = np.reshape(input_node_states, tuple(reshape_list)) + input_node_states = np.broadcast_to(input_node_states, aligned_shape) + multiplier = 1 + for i in range(major_axis + 1, len(aligned_shape)): + multiplier *= aligned_shape[i] + states_list.append((input_idx, major_axis, multiplier, input_node_states)) + return states_list, aligned_node_list diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py new file mode 100644 index 000000000000..11571f2bdef9 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable +"""Dynamic programming tuner.""" +import sys +import numpy as np + +from .base_graph_tuner import BaseGraphTuner +from .dynamic_programming_stage import DPStage +from .utils import has_multiple_inputs, is_input_node + +if sys.version_info[0] == 3: + import queue +else: + import Queue as queue + +class DPTuner(BaseGraphTuner): + """Tuner which uses dynamic programming to solve MDP problem. + + Note: currently dynamic programming is used to solve this MDP problem. However, + this problem is intrinsically non-polynomial. DP can't apply for more complicated + models, such as networks with many element-wise sum operators. In this case, switch + to heuristic algorithm such as PBQP tuner. + """ + def __init__(self, *args, **kwargs): + """Create a dynamic programming tuner. + """ + super(DPTuner, self).__init__(*args, **kwargs) + self._num_states = self._max_num_states = None + self._stage_dict = {} + self._dep_dict = {} + self._counted_nodes_set = set() + + self._global_data_dict = { + "dtype": self._dtype, + "counted_nodes_set": self._counted_nodes_set, + "stage_dict": self._stage_dict, + "in_nodes_dict": self._in_nodes_dict, + "out_nodes_dict": self._out_nodes_dict, + "dep_dict": self._dep_dict, + "node_list": self._node_list, + "input_shapes": self._input_shapes, + "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost + } + + def _check_num_states(self, num_states): + """Track the number of states.""" + self._num_states += num_states + if self._max_num_states is not None: + if self._num_states > self._max_num_states: + raise RuntimeError("Too many states detected while running dynamic " + "programming: got %d states but upper limit is %d." % + (self._num_states, self._max_num_states)) + + def _forward(self): + """Forward pass in DP to generate states for all stages. + """ + self._logger.info("Start forward pass...") + for node_idx in sorted(self._in_nodes_dict.keys()): + stage = DPStage(idx=node_idx, target_ops=self._target_ops, + **self._global_data_dict) + self._check_num_states(stage.full_states.size) + self._stage_dict[node_idx] = stage + self._logger.info("Finished forward pass.") + + def _backward(self): + """Backward pass in DP to generate optimal solution. + """ + self._logger.info("Start backward pass...") + input_names = self._input_shapes.keys() + optimal_record_dict = {} + # Pick optimal schedule for output nodes + output_idx_list = [] + for key, val in self._out_nodes_dict.items(): + if not val: + output_idx_list.append(key) + states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict, + self._node_list) + num_states = states_list[0][3].size + self._check_num_states(num_states * len(output_idx_list)) + aligned_node_shape = states_list[0][3].shape + min_time = 0 + min_pos = -1 + for states in states_list: + min_time += np.amax(states[3]) + flatten_states_list = [current_states[3].flatten() for current_states in states_list] + for i in range(num_states): + current_time = 0 + for j, current_states in enumerate(states_list): + current_time += flatten_states_list[j][i] + if min_time > current_time: + min_time = current_time + min_pos = i + for i, states in enumerate(states_list): + current_major_axis = states[1] + current_sch_idx = (min_pos % (states[2] * + aligned_node_shape[current_major_axis])) // states[2] + optimal_record_dict[aligned_node_list[i]] = current_sch_idx + # Pick optimal schedule for dependencies of output nodes + for i in range(len(states_list), len(aligned_node_list)): + multiplier = 1 + for j in range(i + 1, len(aligned_node_list)): + multiplier *= aligned_node_shape[j] + optimal_record_dict[aligned_node_list[i]] = \ + min_pos // multiplier % aligned_node_shape[i] + + # Backward pass to get optimal schedules for other nodes + bfs_q = queue.Queue() + visited = set() + for out_idx in output_idx_list: + bfs_q.put(out_idx) + while not bfs_q.empty(): + node_idx = bfs_q.get() + visited.add(node_idx) + if is_input_node(self._node_list[node_idx], input_names): + continue + optimal_sch_idx = optimal_record_dict[node_idx] + full_states = self._stage_dict[node_idx].full_states + if not has_multiple_inputs(self._node_list, node_idx, input_names): + input_idx = self._in_nodes_dict[node_idx][0] + if is_input_node(self._node_list[input_idx], input_names): + continue + if input_idx not in visited: + bfs_q.put(input_idx) + if input_idx not in optimal_record_dict: + dep_list = self._stage_dict[node_idx].dep + dep_idx = tuple([optimal_record_dict[item] for item in dep_list]) + tmp = np.argmin(full_states, axis=1) + optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx] + optimal_record_dict[input_idx] = optimal_input_sch_idx + else: + input_idx_list = self._in_nodes_dict[node_idx] + optimal_record_dict[input_idx_list[0]] = optimal_sch_idx + full_states_idx = self._stage_dict[node_idx].full_states_idx + tmp = full_states[optimal_sch_idx] + new_states_idx, new_states_pos = [], [] + visited_states_idx, visited_states_pos = [], [] + for i in range(1, len(full_states_idx)): + if full_states_idx[i] in optimal_record_dict: + visited_states_idx.append(full_states_idx[i]) + visited_states_pos.append(i - 1) + else: + new_states_idx.append(full_states_idx[i]) + new_states_pos.append(i - 1) + if visited_states_idx: + tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos)) + tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])] + min_pos = np.argmin(tmp) + multiplier = 1 + for i in range(len(new_states_idx)): + multiplier *= full_states.shape[new_states_pos[i] + 1] + for pos, idx in zip(new_states_pos, new_states_idx): + multiplier //= full_states.shape[pos + 1] + optimal_record_dict[idx] = min_pos // multiplier + min_pos %= multiplier + for input_idx in input_idx_list: + if input_idx not in visited: + bfs_q.put(input_idx) + + self._optimal_record_dict = optimal_record_dict + for node_idx, _ in self._in_nodes_dict.items(): + if self._node_list[node_idx]["op"] not in self._target_ops: + continue + self._logger.info("Finished backward pass...") + + def run(self, **kwargs): + """Run dynamic programming solver. + """ + max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"] + self._num_states = 0 + self._max_num_states = max_num_states + self._logger.info("Start to run dynamic programming algorithm...") + self._forward() + self._backward() + self._logger.info("Finished DPExecutor run.") diff --git a/python/tvm/autotvm/graph_tuner/pbqp_tuner.py b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py new file mode 100644 index 000000000000..1d7089ef248b --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,too-many-locals +"""Partitioned Boolean Quadratic Programming Tuner""" +from ._base import INVALID_LAYOUT_TIME +from .base_graph_tuner import BaseGraphTuner +from .utils import is_input_node, has_multiple_inputs + + +class PBQPTuner(BaseGraphTuner): + """An approximation method to deal with intractably + large size of graph tuning problem. + + This graph coloring algorithm mainly comes from: + + Lang Hames and Bernhard Scholz. + Nearly optimal register allocation with pbqp.JMLC 2006. + LNCS, vol.4228,pp. 346-361, 2016 + """ + def __init__(self, *args, **kwargs): + """Create a partitioned boolean quadratic programming tuner. + """ + super(PBQPTuner, self).__init__(*args, **kwargs) + + # Remove input nodes + input_names = self._input_shapes.keys() + for node_idx in self._out_nodes_dict: + if is_input_node(self._node_list[node_idx], input_names): + for out_node_idx in self._out_nodes_dict[node_idx]: + self._in_nodes_dict[out_node_idx].remove(node_idx) + + self._adj_dict = {} + for node_idx in self._in_nodes_dict: + self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + \ + list(self._out_nodes_dict[node_idx]) + + self._record_cost_dict = {} + for key in self._in_nodes_dict: + self._record_cost_dict[key] = [] + for record in self._node_list[key]["record_candidates"]: + self._record_cost_dict[key].append(record[1].costs[0]) + + self._max_degree = -1 + self._node_degree_dict = {} + for node_idx in self._in_nodes_dict: + node_degree = self._get_degree(node_idx) + self._node_degree_dict[node_idx] = node_degree + self._max_degree = max(self._max_degree, node_degree) + + self._stack = [] + self._buckets = [[] for _ in range(self._max_degree + 2)] + for node_idx in sorted(self._in_nodes_dict): + node_degree = self._get_degree(node_idx) + self._buckets[node_degree].append(node_idx) + + self._is_optimal = True + + def _get_degree(self, node_idx): + """Get node degree. + """ + return len(self._adj_dict[node_idx]) + + def _reorder_adj_nodes(self, node_idx): + """Update buckets list with current adjacency list. + """ + for adj_node in self._adj_dict[node_idx]: + current_degree = self._get_degree(adj_node) + prev_degree = self._node_degree_dict[adj_node] + if prev_degree != current_degree: + self._buckets[prev_degree].remove(adj_node) + self._buckets[current_degree].insert(0, adj_node) + self._node_degree_dict[adj_node] = current_degree + + def _remove_node(self, node_idx): + """Remove node from graph. Update adjacency list accordingly. + """ + node_degree = self._get_degree(node_idx) + self._buckets[node_degree].remove(node_idx) + for adj_node in self._adj_dict[node_idx]: + self._adj_dict[adj_node].remove(node_idx) + + def _insert_edge(self, node_x, node_y, adj_cost_matrix): + """Insert an edge between two nodes. + """ + self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix + self._layout_transform_interlayer_cost[(node_y, node_x)] = [] + for i in range(len(adj_cost_matrix[0])): + self._layout_transform_interlayer_cost[(node_y, node_x)].append([]) + for cost_vec in adj_cost_matrix: + self._layout_transform_interlayer_cost[(node_y, node_x)][i] \ + .append(cost_vec[i]) + + self._adj_dict[node_x].append(node_y) + self._adj_dict[node_y].append(node_x) + + def _backward_insert_node(self, node_idx): + """Reinsert node in backward pass. + """ + for adj_node in self._adj_dict[node_idx]: + self._adj_dict[adj_node].append(node_idx) + + def _RI_reduction(self, node_idx): + """Reduce nodes with degree 1. + """ + adj_node = self._adj_dict[node_idx][0] + ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)] + for i, cost_vec in enumerate(ltf_matrix): + min_cost = INVALID_LAYOUT_TIME + for j, cost in enumerate(cost_vec): + min_cost = min(min_cost, cost + self._record_cost_dict[node_idx][j]) + self._record_cost_dict[adj_node][i] += min_cost + self._remove_node(node_idx) + self._reorder_adj_nodes(node_idx) + self._stack.append(node_idx) + + def _RII_reduction(self, node_idx): + """Reduce nodes with degree 2. + """ + adj_node_x, adj_node_y = self._adj_dict[node_idx] + ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)] + ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)] + delta_matrix = [[] for _ in range(len(ltf_matrix_x))] + for i, cost_vec_x in enumerate(ltf_matrix_x): + for j, cost_vec_y in enumerate(ltf_matrix_y): + min_cost = INVALID_LAYOUT_TIME + for k in range(len(self._record_cost_dict[node_idx])): + min_cost = min(min_cost, cost_vec_x[k] + cost_vec_y[k] + + self._record_cost_dict[node_idx][k]) + delta_matrix[i].append(min_cost) + + if adj_node_x == adj_node_y: + for i, delta_row in enumerate(delta_matrix): + self._record_cost_dict[adj_node_x][i] += delta_row[i] + elif adj_node_x in self._adj_dict[adj_node_y]: + for i, _ in enumerate(delta_matrix): + for j, delta in enumerate(delta_matrix[i]): + self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] \ + += delta + self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] \ + += delta + else: + self._insert_edge(adj_node_x, adj_node_y, delta_matrix) + + self._remove_node(node_idx) + self._reorder_adj_nodes(node_idx) + self._stack.append(node_idx) + + def _RN_reduction(self, node_idx): + """Reduce nodes with degree greater than 2. + """ + min_cost = INVALID_LAYOUT_TIME + record_idx = -1 + + for i, record_cost in enumerate(self._record_cost_dict[node_idx]): + current_cost = record_cost + for adj_node in self._adj_dict[node_idx]: + ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)] + adj_record_cost = list(self._record_cost_dict[adj_node]) + for j, ltf_cost in enumerate(ltf_matrix[i]): + adj_record_cost[j] += ltf_cost + current_cost += min(adj_record_cost) + if current_cost < min_cost: + min_cost = current_cost + record_idx = i + + if record_idx < 0: + raise RuntimeError("Can't find a soltuion for node %d when " + "applying RN reduction" % node_idx) + self._optimal_record_dict[node_idx] = record_idx + self._is_optimal = False + + for adj_node in self._adj_dict[node_idx]: + ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)] + for i, ltf_cost in enumerate(ltf_matrix[record_idx]): + self._record_cost_dict[adj_node][i] += ltf_cost + + self._remove_node(node_idx) + self._reorder_adj_nodes(node_idx) + self._stack.append(node_idx) + + def _forward(self): + """Forward pass in PBQP to reduce nodes. + """ + while True: + if self._buckets[1]: + node_idx = self._buckets[1][0] + self._RI_reduction(node_idx) + elif self._max_degree >= 2 and self._buckets[2]: + node_idx = self._buckets[2][0] + self._RII_reduction(node_idx) + elif self._max_degree >= 3: + max_degree_node = -1 + for i in range(self._max_degree, 2, -1): + if self._buckets[i]: + max_degree_node = self._buckets[i][0] + self._RN_reduction(max_degree_node) + break + if max_degree_node < 0: + break + else: + break + + def _backward(self): + """Backward pass in PBQP to generate optimal solution. + """ + # Solve nodes left in the forward graph + for node_idx in self._buckets[0]: + record_costs = self._record_cost_dict[node_idx] + min_cost = min(record_costs) + self._optimal_record_dict[node_idx] = record_costs.index(min_cost) + + # Solve nodes with one or two degrees + for node_idx in reversed(self._stack): + self._backward_insert_node(node_idx) + if node_idx not in self._optimal_record_dict: + record_costs = list(self._record_cost_dict[node_idx]) + for adj_node in self._adj_dict[node_idx]: + adj_optimal_idx = self._optimal_record_dict[adj_node] + for i, _ in enumerate(record_costs): + record_costs[i] += \ + self._layout_transform_interlayer_cost \ + [(node_idx, adj_node)][i][adj_optimal_idx] + min_cost = min(record_costs) + self._optimal_record_dict[node_idx] = record_costs.index(min_cost) + + def run(self, **kwargs): + """Run partitioned boolean quadratic programming tuner. + """ + self._logger.info("Start to run PBQP algorithm...") + # Define virtual record lists and layout transformaton matrices + # for multi-input nodes. + input_names = self._input_shapes.keys() + temp = {} + for key, val in self._in_nodes_dict.items(): + target_input_idx = -1 + target_input_pos = -1 + if has_multiple_inputs(self._node_list, key, input_names): + for i, item in enumerate(val): + if not is_input_node(self._node_list[item], input_names): + target_input_idx = item + target_input_pos = i + break + temp[(target_input_idx, key)] = [] + record_candidates = self._node_list[target_input_idx]["record_candidates"] + for j in range(len(record_candidates)): + temp[(target_input_idx, key)].append([]) + for k in range(len(record_candidates)): + temp[(target_input_idx, key)][j].append(0 if j == k + else INVALID_LAYOUT_TIME) + + for j in range(target_input_pos + 1, len(val)): + input_idx = val[j] + if is_input_node(self._node_list[input_idx], input_names): + continue + temp[(input_idx, key)] = \ + self._layout_transform_interlayer_cost[(input_idx, target_input_idx)] + self._layout_transform_interlayer_cost.update(temp) + + # Create reverse layout transformation matrices + temp = {} + for idx_pair, ltf_matrix in self._layout_transform_interlayer_cost.items(): + reverse_key = (idx_pair[1], idx_pair[0]) + reverse_matrix = [[] for _ in range(len(ltf_matrix[0]))] + for i, _ in enumerate(ltf_matrix): + for j, ltf in enumerate(ltf_matrix[i]): + reverse_matrix[j].append(ltf) + temp[reverse_key] = reverse_matrix + self._layout_transform_interlayer_cost.update(temp) + + self._forward() + self._backward() + is_optimal = "optimal" if self._is_optimal else "sub-optimal" + msg = "Finished PBQPExecutor run. Got %s solution." % is_optimal + self._logger.info(msg) diff --git a/python/tvm/autotvm/graph_tuner/utils/__init__.py b/python/tvm/autotvm/graph_tuner/utils/__init__.py new file mode 100644 index 000000000000..8b36e752bdef --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/utils/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Graph tuner utility functions""" +from __future__ import absolute_import + +from . import traverse_graph +from . import utils + +from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \ + get_out_nodes +from .utils import has_multiple_inputs, is_input_node, bind_inputs diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py new file mode 100644 index 000000000000..08f1017e7fb8 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access +"""API for graph traversing.""" +import threading + +import topi + +from tvm import relay, autotvm +from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.ty import TupleType, TensorType +from tvm.autotvm.task import TaskExtractEnv + +from .._base import RULE_OUT_NODE_NAMES +from .utils import has_multiple_inputs, is_input_node + + +# Setup relay op base name -> topi compute functions +# NOTE: To add more ops, change the following dictionary. +OP2COMPUTE = { + "conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw], +} + + +def expr2graph(expr, target_ops, node_dict, node_list): + """Convert relay expr to graph data structure + and fetch workloads of target operators. + + Parameters + ---------- + expr : tvm.relay.Expr.Function + Input relay function expression. + + target_ops: List of str + List of target relay base op name + + node_dict : dictionary from tvm.relay.Expr to int + Dictionary to record node index + + node_list : list of dictionary + List of nodes which contains all expr in the input relay function. + Each node will be stored as a dictionary in the format of + {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type], + "name": str, "workloads": [tuple], "topi_op": [function]} + """ + env = TaskExtractEnv.get(allow_duplicate=True) + topi_funcs = [] + for op_name in target_ops: + if op_name not in OP2COMPUTE: + raise RuntimeError("Not supported relay op in graph tuner: %s" + % op_name) + topi_funcs += OP2COMPUTE[op_name] + env.reset(topi_funcs) + _expr2graph_impl(expr, target_ops, node_dict, node_list) + task_pos = 0 + for node_entry in node_list: + if node_entry["op"] in target_ops: + task_name, args = env.task_collection[task_pos] + task = autotvm.task.create(task_name, args, + target="llvm", + target_host=None, + template_key='direct') + node_entry["workloads"] = [task.workload] + node_entry["topi_op"] = [task_name] + task_pos += 1 + + +def _expr2graph_impl(expr, target_ops, node_dict, node_list): + """Implementation to convert relay expr to graph data structure + """ + def _traverse_expr(node): + if node in node_dict: + return + node_index = len(node_list) + node_entry = {"node": node, "inputs": [], "types": [], + "op": "null", "name": None} + + if isinstance(node, Call): + op_name = node.op.name.split(".")[-1] + node_entry["op"] = op_name + for arg in node.args: + in_node_idx = node_dict[arg] + if isinstance(arg, (Tuple, TupleGetItem)): + node_entry["inputs"] += node_list[in_node_idx]["inputs"] + else: + node_entry["inputs"].append([in_node_idx, 0, 0]) + infer_out = relay.ir_pass.infer_type(node) + out_type = infer_out._checked_type_ + if isinstance(out_type, TensorType): + node_entry["types"].append(out_type) + elif isinstance(out_type, TupleType): + for tupe_type in out_type.fields: + node_entry["types"].append(tupe_type) + else: + raise RuntimeError("Unsupported output type %s in operator %s" + % (type(out_type), op_name)) + + # Utilize tracing target to fetch workload with topo-order. + # Since we only need workload, dummy target can be used to + # create task. + if op_name in target_ops: + params = [] + for i, input_idx in enumerate(node_entry["inputs"]): + input_node_entry = node_list[input_idx[0]] + input_type = input_node_entry["types"][input_idx[1]] + if not isinstance(input_node_entry["node"], (Var, Call)): + raise RuntimeError("Graph tuner can only tune target " + "operators with input node of type " + "relay.expr.Var or relay.expr.Call. Now " + "find a target op %s with input type %s" + % (op_name, str(type(input_node_entry["node"])))) + free_var = relay.Var("var_%d" % i, input_type) + params.append(free_var) + call = relay.Call(node.op, params, node.attrs) + func = relay.Function(params, call) + relay.backend.compile_engine.get().clear() + build_thread = threading.Thread(target=relay.build, + args=(func, + "llvm -device=tracing", + None, + None)) + build_thread.start() + build_thread.join() + elif isinstance(node, Var): + node_entry["name"] = node.name_hint + node_entry["types"] = [node.type_annotation] + elif isinstance(node, Function): + # Ignore root node since it equals to input function expression + if node != expr: + _expr2graph_impl(node, target_ops, node_dict, node_list) + return + elif isinstance(node, TupleGetItem): + node_entry["op"] = "TupleGetItem" + in_node_idx = node_dict[node.tuple_value] + node_entry["inputs"].append([in_node_idx, node.index, 0]) + elif isinstance(node, Tuple): + node_entry["op"] = "Tuple" + for tuple_item in node: + in_node_idx = node_dict[tuple_item] + if isinstance(tuple_item, TupleGetItem): + node_entry["inputs"] += node_list[in_node_idx]["inputs"] + elif isinstance(tuple_item, Tuple): + raise RuntimeError("Graph tuner doesn't support nested tuple.") + else: + node_entry["inputs"].append([in_node_idx, 0, 0]) + elif isinstance(node, Constant): + pass + elif isinstance(node, relay.op.op.Op): + return + else: + raise RuntimeError("Not supported relay node type in graph tuning: %s" + % str(type(node))) + node_dict[node] = node_index + node_list.append(node_entry) + + relay.ir_pass.post_order_visit(expr, _traverse_expr) + + +def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names): + """Given a node_list in relay function and a node index, return the + closest ancestor which has op_name as operator name or is multi_input operator. + + If node has multiple inputs, multiple ancestor nodes will be returned. + + Parameters + ---------- + node_list : list of dict of str to object + List of all nodes in a graph. + + visited_dict : dict of int to int + Nodes and corresponding ancestors which have been visited. + + target_ops: List of str + List of target relay base op name + + node_idx : int + Input node index. + + input_names : list of str + Names of graph input nodes. + + Returns + ------- + out : list of int + List of ancestor node index. + """ + if node_idx in visited_dict: + return visited_dict[node_idx] + if is_input_node(node_list[node_idx], input_names): + return [node_idx] + node = node_list[node_idx] + # Rule out injective operators + is_rule_out = False + for item_idx in node["inputs"]: + item = node_list[item_idx[0]] + if item["op"] in RULE_OUT_NODE_NAMES: + is_rule_out = True + break + if is_rule_out: + visited_dict[node_idx] = [] + return [] + + node_direct_ancestor = [] + for item_idx in node["inputs"]: + item = node_list[item_idx[0]] + is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names) + if item["op"] in target_ops or is_multiple_inputs: + node_direct_ancestor.append(item_idx[0]) + else: + tmp = get_direct_ancestor(node_list, visited_dict, target_ops, + item_idx[0], input_names) + for tmp_item in tmp: + node_direct_ancestor.append(tmp_item) + if not has_multiple_inputs(node_list, node_idx, input_names) and node_direct_ancestor: + node_direct_ancestor = [node_direct_ancestor[0]] + visited_dict[node_idx] = node_direct_ancestor + return node_direct_ancestor + + +def get_in_nodes(node_list, target_ops, input_names): + """Create a dictionary mapping from op_name nodes or multi_input + nodes to closest input ancestors. + + Parameters + ---------- + node_list : list of dict of str to object + List of all nodes in a graph. + + target_ops: List of str + List of target relay op + + input_names : list of str + Names of graph input nodes. + + Returns + ------- + out : dict of int to list of int + Dictionary maps node index to closest input ancestors. + """ + + visited_dict = {} + in_node_dict = {} + for i, node in enumerate(node_list): + if node["op"] in RULE_OUT_NODE_NAMES: + continue + get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names) + for key, val in visited_dict.items(): + node = node_list[key] + is_multiple_inputs = has_multiple_inputs(node_list, key, input_names) + if node["op"] in target_ops or is_multiple_inputs: + in_node_dict[key] = val + + # Remove empty nodes + has_empty_node = True + out_node_dict = get_out_nodes(in_node_dict) + while has_empty_node: + empty_nodes = [] + for key, val in in_node_dict.items(): + if not val: + empty_nodes.append(key) + if empty_nodes: + has_empty_node = True + for node in empty_nodes: + del in_node_dict[node] + if node in out_node_dict: + for out_node in out_node_dict[node]: + in_node_dict[out_node].remove(node) + else: + has_empty_node = False + + return in_node_dict + + +def get_out_nodes(in_node_dict): + """Create output dictionary from input dictionary. + + Parameters + ---------- + in_node_dict : dict of int to list of int + Dictionary maps node index to closest input ancestors. + It can be created with get_in_nodes. + + Returns + ------- + out : dict of int to list of int + Dictionary maps node index to closest output nodes. + """ + out_node_dict = {} + for key in in_node_dict: + out_node_dict[key] = [] + for key, val in in_node_dict.items(): + for item in val: + if item in out_node_dict: + out_node_dict[item].append(key) + else: + out_node_dict[item] = [key] + + return out_node_dict diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py new file mode 100644 index 000000000000..6151734299af --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=eval-used,invalid-name,too-many-arguments +"""Utility functions""" +from tvm import relay + + +def has_multiple_inputs(node_list, node_idx, input_names): + """Check whether a node has multiple input nodes + except variable nodes. + + Parameters + ---------- + node_list : list of dict of str to object + List of all nodes in a graph. + + node_idx : int + Node index to be checked. + + input_names : list of str + List of input names of graph. + + Returns + ------- + out : bool + Whether the specified node has multiple input nodes + """ + num_inputs = 0 + node = node_list[node_idx] + for in_idx in node["inputs"]: + in_idx = in_idx[0] + in_node = node_list[in_idx] + # Exclude parameter nodes + if in_node["op"] != "null" or is_input_node(in_node, + input_names): + num_inputs += 1 + return num_inputs > 1 + + +def is_input_node(node_entry, input_names): + """Whether a node is an input node. + + Parameters + ---------- + node_entry : dict + Node entry. + + input_names : list of str + List of input names of graph. + + Returns + ------- + out : bool + whether node is a input node. + """ + return "name" in node_entry and node_entry["name"] in input_names + + +def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): + """Bind input variables of a relay function expression + to new shapes and/or dtypes. + + Parameters + ---------- + expr : tvm.relay.Expr.Function + Input relay function expression. + + input_shapes : dict of str to tuple of int, optional + Input shapes. + + input_dtypes : str or dict of str to str, optional + Input dtypes. + + Returns + ------- + out : tvm.relay.Expr.Function + Bind relay function expression. + """ + if input_shapes is None: + return expr + if isinstance(input_dtypes, str): + input_dtypes = {key : input_dtypes for key in input_shapes.keys()} + + updated_input_dict = {} + for input_name in input_shapes.keys(): + updated_input = relay.var(input_name, shape=input_shapes[input_name], + dtype=input_dtypes[input_name]) + updated_input_dict[input_name] = updated_input + + rebind_dict = {} + for var in expr.params: + if var.name_hint in updated_input_dict: + rebind_dict[var] = updated_input_dict[var.name_hint] + updated_expr = relay.expr.bind(expr, rebind_dict) + + return relay.ir_pass.infer_type(updated_expr) diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index ff50a4ebc81d..0a0e6e1e8ac7 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -28,6 +28,7 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \ FallbackContext, clear_fallback_cache, ApplyGraphBest -from .topi_integration import register_topi_compute, register_topi_schedule +from .topi_integration import register_topi_compute, register_topi_schedule, \ + TaskExtractEnv from .nnvm_integration import extract_from_graph, extract_from_multiple_graph from .relay_integration import extract_from_program, extract_from_multiple_program diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 3c983768ab3e..ef0cb568071c 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -74,7 +74,7 @@ class TaskExtractEnv: """Global environment for extracting tuning tasks from nnvm graph""" current = None - def __init__(self): + def __init__(self, allow_duplicate=False): import topi # topi compute -> autotvm task name @@ -106,6 +106,7 @@ def __init__(self): topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], } + self.allow_duplicate = allow_duplicate self._register_tracing() self._register_topi_task() self.task_collection = [] @@ -123,10 +124,9 @@ def _tracing_topi_compute(*args, **kwargs): assert not kwargs, "Do not support extracting tuning tasks when" \ "kwargs is used in TOPI function call." \ "Please modify it to use only positional args." - if compute_func in self.wanted_topi_funcs: # record this call key = (self.topi_to_task[compute_func], serialize_args(args)) - if key not in self.task_collection: + if self.allow_duplicate or key not in self.task_collection: self.task_collection.append(key) return compute_func.fdefault(*args) _local_scope(topi_compute) @@ -262,16 +262,25 @@ def get_tasks(self): return self.task_collection @staticmethod - def get(): + def get(allow_duplicate=False): """Get the single instance of TaskExtractEnv + Parameters + ---------- + allow_duplicate : boolean + Whether to fetch all workloads in the network, + even though some of them are the same. This is + useful for graph tuning. + Returns ------- env: TaskExtractEnv The single instance of TaskExtractEnv """ if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv() + TaskExtractEnv.current = TaskExtractEnv(allow_duplicate) + else: + TaskExtractEnv.current.allow_duplicate = allow_duplicate return TaskExtractEnv.current diff --git a/tests/python/unittest/test_graph_tuner_core.py b/tests/python/unittest/test_graph_tuner_core.py new file mode 100644 index 000000000000..240da7f88628 --- /dev/null +++ b/tests/python/unittest/test_graph_tuner_core.py @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: We name this test file to start with test_graph_tuner +# to make it execute after zero_rank tensor test cases. This +# helps avoid topi arithmetic operator overloading issue: +# https://github.com/dmlc/tvm/issues/3240. +# TODO: restore the file name after this issue is resolved. +import os +import copy +import numpy as np +import tvm +import tvm.relay.testing + +from tvm import autotvm +from tvm import relay +from tvm.autotvm.task import ConfigEntity +from tvm.autotvm.measure import MeasureResult, MeasureInput +from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner +from test_graph_tuner_utils import create_workload + + +def _create_data(target, dshape, dtype, layout): + data = relay.var("data", shape=dshape, dtype=dtype) + w0 = relay.var("w0_weight") + conv0 = relay.nn.conv2d(data, w0, channels=16, kernel_size=(3, 3), padding=(1, 1)) + w1 = relay.var("w1_weight") + conv1 = relay.nn.conv2d(conv0, w1, channels=32, kernel_size=(1, 1)) + w2 = relay.var("w2_weight") + conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1)) + out = relay.add(conv1, conv2) + net = relay.Function(relay.ir_pass.free_vars(out), out) + net, params = relay.testing.create_workload(net) + tasks = autotvm.task.extract_from_program(net, + target=target, + params=params, + ops=(relay.op.nn.conv2d,)) + wkl_list = [ + create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype), + create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0), (1, 1), layout, layout, dtype, dtype), + create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype), + ] + costs = [0.04, 0.012, 0.03] + config_list = [] + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [3, 1]], + ["tile_oc", "sp", [4, 4]], + ["tile_ow", "sp", [4, 2]], + ["unroll_kw", "ot", True]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [2, 8]], + ["tile_oc", "sp", [1, 32]], + ["tile_oh", "ot", 1], + ["tile_ow", "sp", [4, 2]]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [8, 4]], + ["tile_oc", "sp", [4, 8]], + ["tile_ow", "sp", [2, 4]], + ["unroll_kw", "ot", False]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + + records = [] + for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks): + task.workload = wkl + ms_input = MeasureInput(target=target, task=task, config=config) + ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) + records.append((ms_input, ms_output)) + + ltf_records = [] + ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"] + ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) + ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) + ltf_task = copy.deepcopy(tasks[0]) + ltf_task.workload = ltf_wkl + ms_input = MeasureInput(target=target, task=ltf_task, config=None) + ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1) + ltf_records.append((ms_input, ms_output)) + + ltf_keys = [] + ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"] + ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) + ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) + ltf_keys.append(ltf_wkl) + ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"] + ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) + ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) + ltf_keys.append(ltf_wkl) + ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"] + ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) + ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) + ltf_keys.append(ltf_wkl) + + return net, records, ltf_records, ltf_keys, tasks + + +def test_graph_tuner_layout_transform(): + log_file = "%s/test_tuner.log" % (os.getcwd()) + target = "llvm" + dshape = (1, 3, 8, 8) + dtype = "float32" + layout = "NCHW" + target_ops = [relay.nn.conv2d] + + g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout) + executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file) + executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) + out = executor._layout_transform_perf_records + + num_flops = 0 + total_time = 0 + for record in ltf_records: + ltf_wkl = record[0].task.workload + input_shape = ltf_wkl[1][1] + flops = np.prod(input_shape) + num_flops += flops + total_time += record[1].costs[0] + avg_time = total_time / num_flops + + for ltf_workload in out: + input_shape = ltf_workload[1][1] + flops = 1 + for i in input_shape: + flops *= i + expected_time = flops * avg_time + out_time = out[ltf_workload][1].costs[0] + assert expected_time == out_time, "Inferred layout transformation time mismatch for %s: " \ + "expecting %f but got %f" % (str(ltf_workload), expected_time, + out_time) + + +def test_DPTuner_run(): + log_file = "%s/test_tuner.log" % (os.getcwd()) + target = "llvm" + dtype = "float32" + layout = "NCHW" + dshape = (1, 3, 8, 8) + target_ops = [relay.nn.conv2d] + + g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout) + costs = [0.02, 0.02, 0.045] + config_list = [] + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 3]], + ["tile_oc", "sp", [2, 8]], + ["tile_ow", "sp", [4, 2]], + ["unroll_kw", "ot", True]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [4, 4]], + ["tile_oc", "sp", [2, 16]], + ["tile_oh", "ot", 1], + ["tile_ow", "sp", [4, 2]]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [16, 2]], + ["tile_oc", "sp", [8, 4]], + ["tile_ow", "sp", [2, 4]], + ["unroll_kw", "ot", False]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + for cost, config, task in zip(costs, config_list, tasks): + ms_input = MeasureInput(target=target, task=task, config=config) + ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) + records.append((ms_input, ms_output)) + + executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file) + executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) + executor.run() + out = [record[0].config for record in executor.get_optimal_records()] + expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + assert expected_out == out, "Output mismatch: expecting %s but got %s" \ + % (str(expected_out), str(out)) + assert os.path.isfile(log_file), "No log file with name %s exists." % log_file + + +def test_PBQPTuner_run(): + target = "llvm" + dtype = "float32" + layout = "NCHW" + dshape = (1, 3, 8, 8) + target_ops = [relay.nn.conv2d] + + g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout) + costs = [0.02, 0.02, 0.045] + config_list = [] + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 3]], + ["tile_oc", "sp", [2, 8]], + ["tile_ow", "sp", [4, 2]], + ["unroll_kw", "ot", True]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [4, 4]], + ["tile_oc", "sp", [2, 16]], + ["tile_oh", "ot", 1], + ["tile_ow", "sp", [4, 2]]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [16, 2]], + ["tile_oc", "sp", [8, 4]], + ["tile_ow", "sp", [2, 4]], + ["unroll_kw", "ot", False]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + for cost, config, task in zip(costs, config_list, tasks): + ms_input = MeasureInput(target=target, task=task, config=config) + ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) + records.append((ms_input, ms_output)) + + executor = PBQPTuner(g, {"data": dshape}, records, target_ops, target) + executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) + executor.run() + out = [record[0].config for record in executor.get_optimal_records()] + expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + assert expected_out == out, "Output mismatch: expecting %s but got %s" \ + % (str(expected_out), str(out)) + + +if __name__=="__main__": + test_graph_tuner_layout_transform() + test_DPTuner_run() + test_PBQPTuner_run() diff --git a/tests/python/unittest/test_graph_tuner_utils.py b/tests/python/unittest/test_graph_tuner_utils.py new file mode 100644 index 000000000000..0847166412d2 --- /dev/null +++ b/tests/python/unittest/test_graph_tuner_utils.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: We name this test file to start with test_graph_tuner +# to make it execute after zero_rank tensor test cases. This +# helps avoid topi arithmetic operator overloading issue: +# https://github.com/dmlc/tvm/issues/3240 +# TODO: restore the file name after this issue is resolved. +import tvm + +from tvm import autotvm, relay +from tvm.relay.testing import resnet +from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ + get_out_nodes, expr2graph, bind_inputs +from tvm.relay.expr import Call, TupleGetItem, Tuple +from topi.nn.conv2d import conv2d + + +def create_workload(dshape, kshape, strides, + padding, dilation, layout, + out_layout, dtype, out_dtype): + data = tvm.placeholder(dshape, dtype=dtype) + kernel = tvm.placeholder(kshape, dtype=dtype) + return autotvm.task.args_to_workload([data, kernel, strides, padding, dilation, layout, + out_dtype], conv2d) + + +def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result): + out = has_multiple_inputs(node_list, node_idx, input_names) + assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \ + % (node_list[node_idx]["op"], str(expected_result), str(out)) + + +def test_has_multiple_inputs(): + data = relay.var("data") + out1 = data * relay.expr.const(3.0) + w0 = relay.var("w0") + out2 = relay.nn.conv2d(data, w0) + out = relay.add(out1, out2) + net = relay.Function(relay.ir_pass.free_vars(out), out) + net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)}) + target_ops = ["conv2d"] + node_list = [] + node_dict = {} + expr2graph(net, target_ops, node_dict, node_list) + input_names = ["data"] + verify_has_multiple_inputs(node_list, 2, input_names, False) + verify_has_multiple_inputs(node_list, 4, input_names, False) + verify_has_multiple_inputs(node_list, 5, input_names, True) + + +def test_expr2graph(): + net, _ = resnet.get_workload(num_layers=50, batch_size=1) + node_dict = {} + node_list = [] + target_ops = ["conv2d"] + op_name_list = [] + def _count_node(node): + if not isinstance(node, relay.op.op.Op,): + return + if isinstance(node, Call): + op_name_list.append(node.op.name.split(".")[-1]) + elif isinstance(node, TupleGetItem): + op_name_list.append("TupleGetItem") + elif isinstance(node, Tuple): + op_name_list.append("Tuple") + else: + op_name_list.append("null") + relay.ir_pass.post_order_visit(net, _count_node) + + expr2graph(net, target_ops, node_dict, node_list) + for i, item in enumerate(zip(op_name_list, node_list)): + op_name, node = item + assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \ + % (i, str(op_name), str(node["op"])) + + +def test_get_direct_ancestor(): + data = relay.var("data") + w0 = relay.var("w0") + out1 = relay.nn.conv2d(data, w0) + out2 = relay.add(out1, data * relay.expr.const(5.0)) + out3 = out2 + relay.expr.const(2.5) + w1 = relay.var("w1") + out = relay.nn.conv2d(out3, w1) + net = relay.Function(relay.ir_pass.free_vars(out), out) + net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)}) + target_ops = ["conv2d"] + node_list = [] + node_dict = {} + expr2graph(net, target_ops, node_dict, node_list) + visited_dict = {} + input_names = ["data"] + out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names) + assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out) + + +def test_get_in_nodes(): + data = relay.var("data") + w0 = relay.var("w0") + out1 = relay.nn.conv2d(data, w0) + out2 = relay.add(out1, data) + out3 = out2 + relay.expr.const(2.5) + w1 = relay.var("w1") + out = relay.nn.conv2d(out3, w1) + net = relay.Function(relay.ir_pass.free_vars(out), out) + net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)}) + target_ops = ["conv2d"] + input_names = ["data"] + node_list = [] + node_dict = {} + expr2graph(net, target_ops, node_dict, node_list) + out = get_in_nodes(node_list, target_ops, input_names) + expected_out = {7: [3], 3: [2, 0], 2: [0]} + diff_set = set(out) ^ set(expected_out) + if len(diff_set) != 0: + raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out))) + + +def test_get_out_nodes(): + in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]} + expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []} + out = get_out_nodes(in_nodes_dict) + diff_set = set(out) ^ set(expected_out) + if len(diff_set) != 0: + raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out))) + + + +if __name__ == "__main__": + test_has_multiple_inputs() + test_expr2graph() + test_get_direct_ancestor() + test_get_in_nodes() + test_get_out_nodes() diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 83e0274597d7..57c1d20c422f 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): # not to change by default return None +@tvm.target.generic_func +def conv2d_infer_layout(workload, cfg): + """Infer input/output shapes and layouts from a workload and cfg. + + Parameters + ---------- + workload : tuple + conv2d workload + + cfg : tuple + tvm.autotvm config + + Returns + ------- + Output : [tuple of tuple and str, tuple of tuple and str] + Input shapes and layouts, and output shapes and layouts + """ + raise ValueError("missing register for topi.nn.conv2d_infer_layout") + + def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py index 460f4fefe678..e703becf6d40 100644 --- a/topi/python/topi/nn/depthwise_conv2d.py +++ b/topi/python/topi/nn/depthwise_conv2d.py @@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation, 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] """ raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc") + +@tvm.target.generic_func +def depthwise_conv2d_infer_layout(workload, cfg): + """Infer input/output shapes and layouts from a workload and cfg. + + Parameters + ---------- + workload : tuple + conv2d workload + + cfg : tuple + tvm.autotvm config + + Returns + ------- + Output : [tuple of tuple and str, tuple of tuple and str] + Input shapes and layouts, and output shapes and layouts + """ + raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout") diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index c333892a9918..d9831c8347f3 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -28,7 +28,7 @@ from .. import nn from ..util import get_const_tuple from ..nn.conv2d import conv2d, conv2d_NCHWc, \ - conv2d_alter_layout, _get_workload as _get_conv2d_workload + conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.pad import pad @@ -480,6 +480,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) +@conv2d_infer_layout.register("cpu") +def _conv2d_infer_layout(workload, cfg): + _, data, kernel, strides, padding, dilation, layout, dtype = workload + batch_size, in_channel, in_height, in_width = data[:-1] + out_channel, _, k_height, k_width = kernel[:-1] + out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1 + out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1 + tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic) + in_layout = "NCHW%dc" % tile_ic + out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc) + out_layout = "NCHW%dc" % tile_oc + return ((in_shape, in_layout),), ((out_shape, out_layout),) + + @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct') def _declaration_conv_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype): diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index f570aaf7e70d..6ea11f234759 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -25,7 +25,8 @@ from ..nn.pad import pad from ..util import get_const_tuple from ..nn.util import get_pad_tuple -from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload +from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \ + depthwise_conv2d_infer_layout from .util import get_fp32_len @@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): # change shape with the value in config ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn) - new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn) + new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn) new_data = tvm.placeholder(new_data_shape, data.dtype) new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) @@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): data_layout, out_layout, dtype) s = schedule_depthwise_conv2d_NCHWc(cfg, [C]) return s, [new_data, new_kernel, C] + +@depthwise_conv2d_infer_layout.register("cpu") +def _depthwise_conv2d_infer_layout(workload, cfg): + _, data, kernel, strides, padding, dilation, dtype = workload + batch_size, in_channel, in_height, in_width = data[:-1] + filter_channel, channel_multiplier, k_height, k_width = kernel[:-1] + out_channel = filter_channel * channel_multiplier + out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1 + out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1 + tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic) + in_layout = "NCHW%dc" % tile_ic + out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc) + out_layout = "NCHW%dc" % tile_oc + return ((in_shape, in_layout),), ((out_shape, out_layout),) diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index f100a35e5770..ad35c198bc77 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -30,6 +30,7 @@ from tvm import relay from tvm.relay import testing from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner import tvm.contrib.graph_runtime as runtime ################################################################# @@ -81,6 +82,7 @@ def get_network(name, batch_size): dtype = "float32" model_name = "resnet-18" log_file = "%s.log" % model_name +graph_opt_sch_file = "%s_graph_opt.log" % model_name # Set number of threads used for tuning based on the number of # physical CPU cores on your machine. @@ -157,6 +159,16 @@ def tune_kernels(tasks, autotvm.callback.progress_bar(n_trial, prefix=prefix), autotvm.callback.log_to_file(log_filename)]) +# Use graph tuner to achieve graph level optimal schedules +# Set use_DP=False if it takes too long to finish. +def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True): + target_op = [relay.nn.conv2d] + Tuner = DPTuner if use_DP else PBQPTuner + executor = Tuner(graph, {"data": dshape}, records, target_op, target) + executor.benchmark_layout_transform(min_exec_num=2000) + executor.run() + executor.write_opt_sch2record_file(opt_sch_file) + ######################################################################## # Finally, we launch tuning jobs and evaluate the end-to-end performance. @@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt): # run tuning tasks print("Tuning...") tune_kernels(tasks, **tuning_opt) + tune_graph(net, data_shape, log_file, graph_opt_sch_file) - # compile kernels with history best records - with autotvm.apply_history_best(log_file): + # compile kernels with graph-level best records + with autotvm.apply_graph_best(graph_opt_sch_file): print("Compile...") with relay.build_config(opt_level=3): graph, lib, params = relay.build_module.build( From d0705c596b89f7052f1217bc13f69f1472d7bd90 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Fri, 31 May 2019 05:12:56 +0300 Subject: [PATCH 060/176] [Relay] Handle float16 constants & fix BatchNorm (#3260) --- src/relay/pass/pattern_util.h | 13 ++++++++++++- src/relay/pass/simplify_inference.cc | 8 ++++---- .../relay/test_pass_simplify_inference.py | 17 +++++++++-------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index b44bb682d317..b709f2846b34 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -27,6 +27,7 @@ #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ #define TVM_RELAY_PASS_PATTERN_UTIL_H_ +#include #include #include #include @@ -49,6 +50,9 @@ namespace relay { } else if (type == Float(32)) { \ typedef float DType; \ {__VA_ARGS__} \ + } else if (type == Float(16)) { \ + typedef uint16_t DType; \ + {__VA_ARGS__} \ } else if (type == Int(64)) { \ typedef int64_t DType; \ {__VA_ARGS__} \ @@ -204,7 +208,14 @@ template inline Constant MakeConstantScalar(DataType dtype, T value) { runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); TVM_DTYPE_DISPATCH(dtype, DType, { - *static_cast(arr->data) = value; + if (dtype == Float(16)) { + // convert to float16 + // storage is uint16_t + *static_cast(arr->data) = + __truncXfYf2__(static_cast(value)); + } else { + *static_cast(arr->data) = value; + } }) return ConstantNode::make(arr); } diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index cecebc5c04ed..8dab0c370853 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -36,11 +36,13 @@ Expr BatchNormToInferUnpack(const Attrs attrs, Expr moving_mean, Expr moving_var, Type tdata) { + auto ttype = tdata.as(); + CHECK(ttype); const auto param = attrs.as(); - Expr epsilon = MakeConstantScalar(Float(32), static_cast(param->epsilon)); + Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon)); Expr var_add_eps = Add(moving_var, epsilon); Expr sqrt_var = Sqrt(var_add_eps); - Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var); + Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var); if (param->scale) { scale = Multiply(scale, gamma); @@ -52,8 +54,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs, } int axis = param->axis; - auto ttype = tdata.as(); - CHECK(ttype); auto ndim = ttype->shape.size(); scale = ExpandBiasToMatchAxis(scale, ndim, {axis}); shift = ExpandBiasToMatchAxis(shift, ndim, {axis}); diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index 1387f276b290..aad1d9fc6cf5 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -17,12 +17,12 @@ from tvm import relay as rly from tvm.relay.ir_pass import simplify_inference, alpha_equal -def test_simplify_batchnorm(): +def test_simplify_batchnorm(dtype='float32'): def simple_bn(x, gamma, beta, moving_mean, moving_var, axis=1, epsilon=1e-5, shape=None): # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta - scale = rly.multiply(rly.const(1, 'float32') / - rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma) + scale = rly.multiply(rly.const(1, dtype) / + rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma) shift = rly.add( rly.multiply(rly.negative(moving_mean), scale), beta) num_newaxis = len(shape) - (axis + 1) @@ -33,8 +33,8 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var, def check(dim, axis, nstep): eps = 0.01 - ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32') - ttype2 = rly.TensorType((10,), 'float32') + ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype) + ttype2 = rly.TensorType((10,), dtype) x = rly.var("x", ttype1) beta = rly.var("beta", ttype2) gamma = rly.var("gamma", ttype2) @@ -43,10 +43,10 @@ def check(dim, axis, nstep): y1, y2 = x, x for _ in range(nstep): - y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'), + y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype), gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) y1 = rly.nn.dropout(y1) - y2 = simple_bn(y2 + rly.const(1, 'float32'), + y2 = simple_bn(y2 + rly.const(1, dtype), gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis, shape=ttype1.shape) y1 = rly.ir_pass.infer_type(y1) @@ -60,4 +60,5 @@ def check(dim, axis, nstep): if __name__ == "__main__": - test_simplify_batchnorm() + test_simplify_batchnorm(dtype='float32') + test_simplify_batchnorm(dtype='float16') From 349116992bdb46dc9715f13eb74bf9114c5619e5 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Thu, 30 May 2019 21:11:25 -0700 Subject: [PATCH 061/176] [Bugfix] Fix a memory leak in OpManager (#3263) --- src/relay/ir/op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 4a23f59f9637..b4303e7ac6b1 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,8 +54,8 @@ struct OpManager { std::vector frontend_funcs; // get singleton of the op manager static OpManager* Global() { - static OpManager inst; - return &inst; + static OpManager* inst = new OpManager(); + return inst; } }; From bd4ead263e577f74272ce79d052806dafec7fe80 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 30 May 2019 21:32:33 -0700 Subject: [PATCH 062/176] Jekyll (#3262) --- docker/Dockerfile.ci_jekyll | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 docker/Dockerfile.ci_jekyll diff --git a/docker/Dockerfile.ci_jekyll b/docker/Dockerfile.ci_jekyll new file mode 100644 index 000000000000..5d3cf86dd6f5 --- /dev/null +++ b/docker/Dockerfile.ci_jekyll @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# CI docker Jekyll env for building website +# tag: v0.50 +FROM ubuntu:16.04 + +RUN apt-get update && apt-get install -y sudo wget +RUN apt-get update && apt-get install -y ruby-full build-essential zlib1g-dev +RUN gem install jekyll bundler From d9f4f5db4778342128e8e493fcc4aa3ec35155bd Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 31 May 2019 01:29:54 -0700 Subject: [PATCH 063/176] [Relay][Hashing] Structural hash - incorporate the var type into its hash (#3267) Currently, the BindVar function does not take Var type into account. This causes two same graph structures with different var shapes to have same hash. Structural hash is used for keeping track of which operators we have already compiled. Because of this, two operators with different shapes end up pointing to same compiled code. The failure is encountered at runtime, where the expected input shape asserts are not met. --- src/relay/ir/hash.cc | 3 +++ tests/python/relay/test_pass_alpha_equal.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index c56c4ce17067..c57475476e58 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -219,6 +219,9 @@ class RelayHashHandler: size_t BindVar(const NodeRef& var) { size_t hash = std::hash()(var_counter++); CHECK_EQ(hash_map_.count(var), 0); + if (auto var_node = var.as()) { + hash = Combine(hash, TypeHash(var_node->type_annotation)); + } hash_map_[var] = hash; const auto* ty_param = var.as(); diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 478b433180b9..0e0036565363 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -594,7 +594,24 @@ def test_graph_equal(): # Check the difference in the text format. assert not alpha_equal(z0, z3) +def test_hash_unequal(): + x1 = relay.var("x1", shape=(10, 10), dtype="float32") + y1 = relay.var("y1", shape=(10, 10), dtype="float32") + func1 = relay.Function([x1, y1], relay.add(x1, y1)) + # func2 is exactly same structure with same variables shapes and dtypes + x2 = relay.var("x2", shape=(10, 10), dtype="float32") + y2 = relay.var("y2", shape=(10, 10), dtype="float32") + func2 = relay.Function([x2, y2], relay.add(x2, y2)) + + assert ir_pass.structural_hash(func1) == ir_pass.structural_hash(func2) + + # func3 is same as func1 but with different var shapes + x3 = relay.var("x3", shape=(20, 10), dtype="float32") + y3 = relay.var("y3", shape=(20, 10), dtype="float32") + func3 = relay.Function([x3, y3], relay.add(x3, y3)) + + assert not ir_pass.structural_hash(func1) == ir_pass.structural_hash(func3) if __name__ == "__main__": test_tensor_type_alpha_equal() @@ -617,3 +634,4 @@ def test_graph_equal(): test_op_alpha_equal() test_var_alpha_equal() test_graph_equal() + test_hash_unequal() From a49722c439418d3da27e3cf044d0ae80c87a6874 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Fri, 31 May 2019 12:21:36 -0700 Subject: [PATCH 064/176] Enable uTVM in Jenkinsfile (#3269) --- Jenkinsfile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Jenkinsfile b/Jenkinsfile index af5a2ce3eb42..ea468c24ed53 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -135,6 +135,7 @@ stage('Build') { echo set\\(USE_CUDNN ON\\) >> config.cmake echo set\\(USE_CUDA ON\\) >> config.cmake echo set\\(USE_OPENGL ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-6.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake @@ -157,6 +158,7 @@ stage('Build') { echo set\\(USE_OPENCL ON\\) >> config.cmake echo set\\(USE_ROCM ON\\) >> config.cmake echo set\\(USE_VULKAN ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER clang-6.0\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake @@ -174,6 +176,7 @@ stage('Build') { cd build cp ../cmake/config.cmake . echo set\\(USE_SORT ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake @@ -202,6 +205,7 @@ stage('Build') { cd build cp ../cmake/config.cmake . echo set\\(USE_SORT ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_RPC ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-5.0\\) >> config.cmake From d1bb99f0505b55f569740f09332c7cfdbd46c7dd Mon Sep 17 00:00:00 2001 From: Hua Date: Fri, 31 May 2019 19:42:15 -0700 Subject: [PATCH 065/176] [Bugfix][VTA] PkgConfig cause crash in PYNQ board due to link library (#3257) * [Bugfix][VTA] PkgConfig cause crash in PYNQ board due to link library not exist. Symptom: When run vta_get_started.py with pynq board, host crash and complain "cannot find -lsds_lib" and "cannot find -l:libdma.so" Reproduce: At pynq board, delete the ./build/vta_config.json, then run rpc server. In host machine run vta_get_started.py, issue would reproduce. Analysis: This issue caused by 'PkgConfig' function still using pynq2.1 library which not exist in pynq2.4 anymore, when a "reconfig_runtime" logic of rpc_server.py get triggered , the compile would failed due to link library not exist. Solution: change the link library to libcma.so. * [Document Change][VTA] Change pynq version from 2.3 into 2.4. Issue: pynq 2.3 image not available anymore from pynq download page and pynq 2.4 is the current latest image which available in the said website, after verification, currently VTA work good with pynq 2.4 image, hence update related document from pynq 2.3 to 2.4. --- cmake/modules/VTA.cmake | 2 +- docs/vta/install.md | 2 +- vta/python/vta/pkg_config.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 1adb0aaf387a..1df6c6676fac 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -55,7 +55,7 @@ elseif(PYTHON) set_target_properties(vta PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif(APPLE) - # PYNQ rules for Pynq v2.3 + # PYNQ rules for Pynq v2.4 if(${VTA_TARGET} STREQUAL "pynq") find_library(__cma_lib NAMES cma PATH /usr/lib) target_link_libraries(vta ${__cma_lib}) diff --git a/docs/vta/install.md b/docs/vta/install.md index 8fa779a5d5b8..6c87b4edd288 100644 --- a/docs/vta/install.md +++ b/docs/vta/install.md @@ -84,7 +84,7 @@ This guide covers the following themes: Setup your Pynq board based on the [Pynq board getting started tutorial](http://pynq.readthedocs.io/en/latest/getting_started.html). You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point). -* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.3](http://www.pynq.io/board.html) (released October 3rd 2018), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). +* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.4](http://www.pynq.io/board.html)(released February 22rd 2019), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). * For this test setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Ethernet setup instructions. To be able to talk to the board, make sure to [assign your computer a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip) Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board: diff --git a/vta/python/vta/pkg_config.py b/vta/python/vta/pkg_config.py index 3b2824765010..2c30414ace1a 100644 --- a/vta/python/vta/pkg_config.py +++ b/vta/python/vta/pkg_config.py @@ -77,10 +77,9 @@ def __init__(self, cfg, proj_root): if self.target == "pynq": self.ldflags = [ "-L/usr/lib", - "-lsds_lib", "-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/", "-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/", - "-l:libdma.so"] + "-l:libcma.so"] else: self.ldflags = [] From 2581b69152d83cad835df1f071a36a34ac21006a Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Sat, 1 Jun 2019 00:53:18 -0700 Subject: [PATCH 066/176] [relay][heterogeneous] annotate using visitor (#3261) * annotate using visitor * retrigger CI --- src/relay/pass/device_annotation.cc | 11 +- tests/python/relay/test_pass_annotation.py | 126 ++++++++++++++------- 2 files changed, 94 insertions(+), 43 deletions(-) diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index fa656dbf489e..e2d07619cb0f 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator { } Expr VisitExpr_(const CallNode* call_node) final { - if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) { + if (IsOnDeviceNode(call_node)) { + return this->VisitExpr(call_node->args[0]); + } + + if (IsDeviceCopyNode(call_node)) { return ExprMutator::VisitExpr_(call_node); } @@ -358,6 +362,9 @@ class DeviceInfo { public: void Visit(const Expr& expr) { if (const auto* fn = expr.as()) { + for (const auto& param : fn->params) { + this->VisitExpr(param); + } this->VisitExpr(fn->body); } else { this->VisitExpr(expr); @@ -402,7 +409,7 @@ class DeviceInfo { } void VisitExpr_(const VarNode* vn) final { - post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); + post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); } void VisitExpr_(const LetNode* ln) final { diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 98cf0f15446e..ba2c249693b7 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -21,6 +21,7 @@ import tvm from tvm import relay from tvm.contrib import graph_runtime +from tvm.relay.expr_functor import ExprMutator def test_redundant_annotation(): @@ -34,11 +35,10 @@ def annotated(): add = relay.add(x, y) _add1 = relay.annotation.on_device(add, ctx2) _add2 = relay.annotation.on_device(add, ctx2) - sub = relay.subtract(add, z) + sub1 = relay.subtract(_add1, z) + sub2 = relay.subtract(_add2, z) - func = relay.Function([x, y, z], - relay.Tuple(tvm.convert([_add1, _add2, - sub]))) + func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) @@ -46,9 +46,11 @@ def annotated(): def expected(): add = relay.add(x, y) - copy_add_sub = relay.device_copy(add, ctx2, ctx1) - sub = relay.subtract(copy_add_sub, z) - func = relay.Function([x, y, z], sub) + copy_add_sub1 = relay.device_copy(add, ctx2, ctx1) + sub1 = relay.subtract(copy_add_sub1, z) + copy_add_sub2 = relay.device_copy(add, ctx2, ctx1) + sub2 = relay.subtract(copy_add_sub2, z) + func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) return func annotated_func = relay.ir_pass.infer_type(annotated()) @@ -66,10 +68,9 @@ def test_annotate_expr(): def annotated(): add = relay.add(x, y) _add = relay.annotation.on_device(add, ctx1) - sub = relay.subtract(add, z) + sub = relay.subtract(_add, z) _sub = relay.annotation.on_device(sub, ctx2) - expr = relay.Tuple([sub, _add, _sub]) - expr = relay.ir_pass.infer_type(expr) + expr = relay.ir_pass.infer_type(_sub) expr = relay.ir_pass.rewrite_annotated_ops(expr, ctx1.device_type) return expr @@ -95,12 +96,10 @@ def test_annotate_all(): def annotated(): add = relay.add(x, y) _add = relay.annotation.on_device(add, ctx2) - sub = relay.subtract(add, z) + sub = relay.subtract(_add, z) _sub = relay.annotation.on_device(sub, ctx2) - func = relay.Function([x, y, z], - relay.Tuple(tvm.convert([_add, _sub, - sub]))) + func = relay.Function([x, y, z], _sub) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) @@ -168,6 +167,34 @@ def test_conv_network(): dev1 = tvm.context(1) dev2 = tvm.context(2) + def original(): + conv2d_1 = relay.nn.conv2d( + data1, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + conv2d_2 = relay.nn.conv2d( + data2, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + add = relay.add(conv2d_1, conv2d_2) + conv2d_3 = relay.nn.conv2d( + add, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + + func = relay.Function([data1, data2, weight], conv2d_3) + func = relay.ir_pass.infer_type(func) + func = relay.ir_pass.rewrite_annotated_ops(func, + tvm.context(3).device_type) + return func + + def annotated(): conv2d_1 = relay.nn.conv2d( data1, @@ -183,25 +210,40 @@ def annotated(): kernel_size=(3, 3), padding=(1, 1)) _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2) - add = relay.add(conv2d_1, conv2d_2) + add = relay.add(_conv2d_1, _conv2d_2) _add = relay.annotation.on_device(add, dev1) conv2d_3 = relay.nn.conv2d( - add, + _add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) - func = relay.Function([data1, data2, weight], - relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2, - _conv2d_3, _add, - conv2d_3]))) + func = relay.Function([data1, data2, weight], _conv2d_3) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, tvm.context(3).device_type) return func + class ScheduleConv2d(ExprMutator): + def __init__(self, device): + self.device = device + super().__init__() + + def visit_call(self, expr): + visit = super().visit_call(expr) + if expr.op == tvm.relay.op.get("nn.conv2d"): + return relay.annotation.on_device(visit, self.device) + else: + return visit + + def annotate_with_visitor(func): + sched = ScheduleConv2d(dev2) + func = sched.visit(func) + func = relay.ir_pass.rewrite_annotated_ops(func, dev1.device_type) + return func + def expected(): conv2d_1 = relay.nn.conv2d( data1, @@ -249,10 +291,19 @@ def check_storage_and_device_types(): assert len(set(device_types)) == 2 assert set(device_types) == {1, 2} - annotated_func = annotated() - expected_func = expected() - check_annotated_graph(annotated_func, expected_func) - check_storage_and_device_types() + def test_manual_annotation(): + annotated_func = annotated() + expected_func = expected() + check_annotated_graph(annotated_func, expected_func) + check_storage_and_device_types() + + def test_visitor_annotation(): + annotated_func = annotate_with_visitor(original()) + expected_func = expected() + check_annotated_graph(annotated_func, expected_func) + + test_manual_annotation() + test_visitor_annotation() def run_fusible_network(dev, tgt): @@ -321,12 +372,11 @@ def annotated(): sqrt = relay.sqrt(add) _sqrt = relay.annotation.on_device(sqrt, dev_ctx) log = relay.log(add) - subtract = relay.subtract(sqrt, log) + subtract = relay.subtract(_sqrt, log) exp = relay.exp(subtract) _exp = relay.annotation.on_device(exp, dev_ctx) - func = relay.Function([x, y], - relay.Tuple(tvm.convert([_sqrt, _exp, exp]))) + func = relay.Function([x, y], _exp) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) @@ -364,19 +414,16 @@ def test_fuse_all(device, tgt): def annotated(): add = relay.add(x, y) _add = relay.annotation.on_device(add, dev_ctx) - sqrt = relay.sqrt(add) + sqrt = relay.sqrt(_add) _sqrt = relay.annotation.on_device(sqrt, dev_ctx) - log = relay.log(add) + log = relay.log(_add) _log = relay.annotation.on_device(log, dev_ctx) - subtract = relay.subtract(sqrt, log) + subtract = relay.subtract(_sqrt, _log) _subtract = relay.annotation.on_device(subtract, dev_ctx) - exp = relay.exp(subtract) + exp = relay.exp(_subtract) _exp = relay.annotation.on_device(exp, dev_ctx) - func = relay.Function([x, y], - relay.Tuple(tvm.convert([_add, _sqrt, _log, - _subtract, _exp, - exp]))) + func = relay.Function([x, y], _exp) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) @@ -401,8 +448,7 @@ def annotated(): exp = relay.exp(subtract) _exp = relay.annotation.on_device(exp, cpu_ctx) - func = relay.Function([x, y], - relay.Tuple(tvm.convert([_exp, exp]))) + func = relay.Function([x, y], _exp) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) @@ -472,11 +518,9 @@ def annotated(): _add = relay.annotation.on_device(add, dev_ctx) mul = relay.multiply(c, d) _mul = relay.annotation.on_device(mul, cpu_ctx) - sub = relay.subtract(add, mul) + sub = relay.subtract(_add, _mul) _sub = relay.annotation.on_device(sub, dev_ctx) - func = relay.Function([a, b, c, d], - relay.Tuple(tvm.convert([_add, _mul, - _sub, sub]))) + func = relay.Function([a, b, c, d], _sub) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) From 61dda87a248cfa8721ef85bde53b757c50d1c19b Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Sat, 1 Jun 2019 11:16:16 -0700 Subject: [PATCH 067/176] Update tflite tutorial to use TFLite r1.13 schema (#3271) --- tutorials/frontend/from_tflite.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index 01669818bb99..f8cdd991c984 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -52,7 +52,7 @@ flatc --version # Get the TFLite schema. - wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/contrib/lite/schema/schema.fbs + wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs # Generate TFLite package. flatc --python schema.fbs @@ -144,7 +144,7 @@ def extract(path): # target x86 CPU target = "llvm" -with relay.transform.build_config(opt_level=3): +with relay.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ###################################################################### @@ -180,11 +180,9 @@ def extract(path): label_file = "labels_mobilenet_quant_v1_224.txt" label_path = download_testdata(label_file_url, label_file, module='data') -# map id to 1001 classes -labels = dict() +# list of 1001 classes with open(label_path) as f: - for id, line in enumerate(f): - labels[id] = line + labels = f.readlines() # convert result to 1D data predictions = np.squeeze(tvm_output) From 92a61c3adb73d905120487a9c728950202184389 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Mon, 3 Jun 2019 18:52:31 +0300 Subject: [PATCH 068/176] [ARITH] Bugfix: check arg positiveness for mod rules (#3279) --- src/arithmetic/rewrite_simplify.cc | 4 +++- .../unittest/test_arith_rewrite_simplify.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 00198d9b140a..ee3265618876 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -634,10 +634,12 @@ Mutate_(const Mod* op, const Expr& self) { TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, c2.Eval()->value > 0 && + c1.Eval()->value >= 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0)); @@ -645,7 +647,7 @@ Mutate_(const Mod* op, const Expr& self) { c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual((y * c1).Eval(), 0)); // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 1b03253c9a0f..ee113e101cce 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -302,9 +302,11 @@ def test_div_index_simplify(): def test_mod_index_simplify(): ck = RewriteChecker() - x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") + x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z") ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) + ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True) + ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True) ck.verify(x * 10 % 2, 0) ck.verify((x * 10 + y) % 2, y % 2) @@ -317,6 +319,19 @@ def test_mod_index_simplify(): ck.verify((x + y * 10) % -2, x % 2) ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1) + ck.verify(x * (-10) % 2, 0) + ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2) + ck.verify((x + (-10)) % 2, (x + (-10)) % 2) + ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2) + ck.verify(x * (-10) % -2, 0) + + ck.verify(nx * 10 % 2, 0) + ck.verify((nx * (-10) + y) % 2, y % 2) + ck.verify((x + ny * (-10)) % 2, x % 2) + ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1) + ck.verify(nx * 10 % -2, 0) + ck.verify((nx * (-10) + y) % -2, y % 2) + ck.verify((x + ny * (-10)) % -2, x % 2) def test_min_index_simplify(): ck = RewriteChecker() From 3919e016cf55136db25a6bb4560be1dbe931219f Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 3 Jun 2019 10:40:38 -0700 Subject: [PATCH 069/176] [RELAY][TRANSFORM] Migrate buildmodule to transform (#3251) --- include/tvm/relay/module.h | 26 +- include/tvm/relay/pass.h | 20 ++ include/tvm/relay/transform.h | 90 ++++- python/tvm/relay/build_module.py | 94 +----- python/tvm/relay/transform.py | 199 +++++++++++ src/relay/backend/build_module.cc | 370 +++++++-------------- src/relay/pass/alter_op_layout.cc | 27 +- src/relay/pass/canonicalize_ops.cc | 17 + src/relay/pass/combine_parallel_conv2d.cc | 17 + src/relay/pass/dead_code.cc | 5 +- src/relay/pass/device_annotation.cc | 8 +- src/relay/pass/eliminate_common_subexpr.cc | 17 + src/relay/pass/fold_constant.cc | 8 +- src/relay/pass/fold_scale_axis.cc | 42 ++- src/relay/pass/forward_rewrite.cc | 4 +- src/relay/pass/fuse_ops.cc | 7 +- src/relay/pass/partial_eval.cc | 9 +- src/relay/pass/pass_manager.cc | 166 +++++---- src/relay/pass/simplify_inference.cc | 17 + src/relay/pass/to_a_normal_form.cc | 5 +- src/relay/pass/to_graph_normal_form.cc | 5 +- src/relay/pass/type_infer.cc | 19 ++ tests/cpp/relay_transform_sequential.cc | 111 +++++++ tests/python/relay/test_pass_manager.py | 51 ++- 24 files changed, 879 insertions(+), 455 deletions(-) create mode 100644 tests/cpp/relay_transform_sequential.cc diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 6441fb3f5b9c..3966a6258a20 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - void Add(const GlobalVar& var, const Function& func, bool update = false); + TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. * \param type The type definition. */ - void AddDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); /*! * \brief Add a function to the global environment. @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode { * * It does not do type inference as Add does. */ - void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); /*! * \brief Update a function in the global environment. * \param var The name of the global function to update. * \param func The new function. */ - void Update(const GlobalVar& var, const Function& func); + TVM_DLL void Update(const GlobalVar& var, const Function& func); /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. */ - void Remove(const GlobalVar& var); + TVM_DLL void Remove(const GlobalVar& var); /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalVar GetGlobalVar(const std::string& str); + TVM_DLL GlobalVar GetGlobalVar(const std::string& str); /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalTypeVar GetGlobalTypeVar(const std::string& str); + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ - Function Lookup(const GlobalVar& var); + TVM_DLL Function Lookup(const GlobalVar& var); /*! * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - Function Lookup(const std::string& name); + TVM_DLL Function Lookup(const std::string& name); /*! * \brief Lookup a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ - TypeData LookupDef(const GlobalTypeVar& var); + TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); /*! * \brief Lookup a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TypeData LookupDef(const std::string& var); + TVM_DLL TypeData LookupDef(const std::string& var); /*! * \brief Update the functions inside this environment by * functions in another environment. * \param other The other environment. */ - void Update(const Module& other); + TVM_DLL void Update(const Module& other); /*! \brief Construct a module from a standalone expression. * @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode { * * \returns A module with expr set as the entry point. */ - static Module FromExpr( + TVM_DLL static Module FromExpr( const Expr& expr, const tvm::Map& global_funcs = {}); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 67cc5df82407..81587339f2ad 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -358,6 +358,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); +/*! + * \brief Collect the device anntation operators. + * + * \param expr The expression. + * + * \return The annotated expression to device type mapping for annotation ops. + */ +TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); + /*! * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * @@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); */ TVM_DLL Expr PartialEval(const Expr& e); +/*! + * \brief Bind the free variables to a Relay expression. + * + * \param expr The expression. + * \param bind_map The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1c1b60813b78..793bc981ea61 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -58,9 +58,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -292,9 +294,9 @@ class Sequential : public Pass { * \param passes The passes to apply. * \param pass_info The pass metadata. */ - TVM_DLL Sequential(tvm::Array passes, - PassInfo pass_info); -/*! + TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); + + /*! * \brief The constructor of `Sequential`. * * \param passes The passes to apply. @@ -311,7 +313,6 @@ class Sequential : public Pass { using ContainerType = Sequential; }; - /* * \brief Create a module pass. * @@ -339,7 +340,7 @@ Pass CreateModulePass( * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, Module, PassContext)>& pass_func, + Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm(); */ TVM_DLL Pass PartialEval(); +/*! + * \brief Simplify certain operators during inference. For example, batch norm + * will be unpacked into a number of simplified operators. + * + * \return The Pass. + */ +TVM_DLL Pass SimplifyInference(); + +/*! + * \brief Infer the type of an expression. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \return The pass. + */ +TVM_DLL Pass InferType(); + +/*! + * \brief Search and eliminate common subexpression. For example, if there are + * two expressions evaluated to an identical value, a single variable is created + * and these two expressions are replaced by this variable. + * + * \param fskip The callback argument that allows to skip certain expressions. + * + * \return The pass. + */ +TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); + +/*! + * \brief Combine parallel 2d convolutions into a single convolution if the + * number of branches of this conv2d operator is not less than + * `min_num_branch`. + * + * \param min_num_branches The minimun number of branches. + * + * \return The pass. + */ +TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); + +/*! + * \brief Backward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass BackwardFoldScaleAxis(); + +/*! + * \brief Forward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass ForwardFoldScaleAxis(); + +/*! + * \brief A sequential pass that executes ForwardFoldScaleAxis and + * BackwardFoldScaleAxis passes. + * + * \return The pass. + */ +TVM_DLL Pass FoldScaleAxis(); + +/*! + * \brief Canonicalize some operators to the simplified operators. For example, + * bias_add can be canonicalized to expand_dims and broadcast_add. + * + * \return The pass. + */ +TVM_DLL Pass CanonicalizeOps(); + +/*! + * \brief Alternate the layouts of operators or replace primitive operators + * with other expressions. + * + * \return The pass. + */ +TVM_DLL Pass AlterOpLayout(); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 6cee393d5f91..8f9b0481a22c 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -20,7 +20,6 @@ """ import numpy as np -from tvm._ffi.runtime_ctypes import TVMContext from tvm import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt @@ -28,7 +27,6 @@ from . import ir_pass from . import ty as _ty from . import expr as _expr -from . import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -61,10 +59,6 @@ def __init__(self): self._get_graph_json = self.mod["get_graph_json"] self._get_module = self.mod["get_module"] self._build = self.mod["build"] - self._add_pass = self.mod["add_pass"] - self._disable_pass = self.mod["disable_pass"] - self._set_opt_level = self.mod["set_opt_level"] - self._set_fallback_device = self.mod["set_fallback_device"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -106,8 +100,9 @@ def build(self, func, target=None, target_host=None, params=None): """ target = _update_target(target) - # Setup the build configurations passed in through `with build_config`. - self._setup_build_config(params) + # Setup the params. + if params: + self._set_params(params) # Build the function self._build(func, target, target_host) # Get artifacts @@ -117,41 +112,6 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params - def _setup_build_config(self, params): - cfg = _transform.PassContext.current() - - # Set opt_level. - self.set_opt_level(cfg.opt_level) - - # Set fallback device if it is available. - if cfg.fallback_device: - self.set_fallback_device(cfg.fallback_device) - - # Add required passes. - if cfg.required_pass: - passes = set() - if isinstance(cfg.required_pass, (list, tuple, set)): - passes = set(cfg.required_pass) - else: - raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.required_pass))) - for pass_name in passes: - self.add_pass(pass_name) - - # Add disabled passes. - if cfg.disabled_pass: - passes = set() - if isinstance(cfg.disabled_pass, (list, tuple, set)): - passes = set(cfg.disabled_pass) - else: - raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disabled_pass))) - for pass_name in passes: - self.disable_pass(pass_name) - - if params: - self._set_params(params) - def _set_params(self, params): inputs = {} for name, param in params.items(): @@ -160,28 +120,6 @@ def _set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) - def add_pass(self, pass_name): - """Add a pass to the pass list. - - Parameters - ---------- - pass_name : str - The name of the pass that will be added to the list of passes used - for optimizations. - """ - self._add_pass(pass_name) - - def disable_pass(self, pass_name): - """Add a pass to the disabled pass list. - - Parameters - ---------- - pass_name : str - The name of a pass. This pass will be added to the list of passes - that are disabled during optimization. - """ - self._disable_pass(pass_name) - def get_json(self): """Return the json file of the built program.""" return self._get_graph_json() @@ -198,32 +136,6 @@ def get_params(self): ret[key] = value.data return ret - def set_opt_level(self, level): - """Set the optimization level. - - Parameters - ---------- - level : int - The optimization level for build. - """ - self._set_opt_level(level) - - def set_fallback_device(self, fallback_device): - """Set the fallback device for heterogeneous execution. - - Parameters - ---------- - fallback_device : str or tvm.TVMContext - The fallback device used for heterogeneous execution. - """ - if isinstance(fallback_device, (int, str)): - fallback_device = _nd.context(fallback_device) - if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str, int, or " + - "TVMContext but received: {}".format(type(fallback_device))) - - self._set_fallback_device(fallback_device.device_type) - def build(func, target=None, target_host=None, params=None): """Helper function that builds a Relay function to run on TVM graph diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index a7887c630c76..38079b010e7d 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck +# pylint: disable=invalid-name """ This file contains the pass manager for Relay which exposes different granularity of interfaces for users to implement and use passes more @@ -394,3 +395,201 @@ def create_function_pass(pass_func): if pass_func: return create_function_pass(pass_func) return create_function_pass + + +def InferType(): + """Infer the type of an expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered type inference pass. + """ + return _transform.InferType() + + +def FoldScaleAxis(): + """Fold the scaling of axis into weights of conv2d/dense. This pass will + invoke both forward and backward scale folding. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to fold expressions. + + Note + ---- + Internally, we will call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.FoldScaleAxis() + + +def SimplifyInference(): + """Simplify the data-flow graph for inference phase. An simplified expression + which is semantically equal to the input expression will be returned. + + Returns + ------- + ret: tvm.relay.Pass + The registered to perform operator simplification. + """ + return _transform.SimplifyInference() + + +def CanonicalizeOps(): + """ Canonicalize special operators to basic operators. + This can simplify followed analysis. (e.g. expanding bias_add to + expand_dims and broadcast_add.) + + Returns + ------- + ret: tvm.relay.Pass + The registered pass performing the canonicalization. + """ + return _transform.CanonicalizeOps() + + +def DeadCodeElimination(): + """ Remove expressions which does not effect the program result (dead code). + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that eliminates the dead code in a Relay program. + """ + return _transform.DeadCodeElimination() + + +def FoldConstant(): + """Fold the constant expression in expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for constant folding. + """ + return _transform.FoldConstant() + + +def FuseOps(fuse_opt_level=-1): + """Fuse operators in an expr to a larger operator according to some rules. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for operator fusion. + """ + return _transform.FuseOps(fuse_opt_level) + + +def CombineParallelConv2D(min_num_branches=3): + """Combine multiple conv2d operators into one. + + Parameters + ---------- + min_num_branches : int + The minimum number of required parallel branches for performing this + optimization. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that combines parallel conv2d operators. + """ + return _transform.CombineParallelConv2D(min_num_branches) + + +def AlterOpLayout(): + """Alternate the layouts of operators or replace primitive operators with + other expressions. + This pass can be used for computing convolution in custom layouts or + other general weight pre-transformation. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that alters the layout of operators. + """ + return _transform.AlterOpLayout() + + +def RewriteAnnotatedOps(fallback_device): + """Rewrite the annotated program where annotation operators, e.g. + `on_deivce`, mark which device an expression should be scheduled to. + This pass helps heterogeneous execution where different operators may need + to be allocated on various devices. + + Parameters + ---------- + fallback_device : int + The fallback device type. It is also used as the default device for + operators with no annotated device. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that rewrites an expression with annotated + `on_device` operators. + """ + return _transform.RewriteDeviceAnnotation(fallback_device) + + +def ToANormalForm(): + """Turn Graph Normal Form expression into A Normal Form Expression. + The scope of the root expression is the global scope. + The scope of any non root expression is the least common ancestor of all it's scope. + Values are ordered by post-DFS order in each scope. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that transforms an expression into A Normal Form. + """ + return _transform.ToANormalForm() + + +def ToGraphNormalForm(): + """Turn A Normal Form expression into Graph Normal Form expression + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that transforms an expression into Graph Normal Form. + """ + return _transform.ToGraphNormalForm() + + +def EliminateCommonSubexpr(fskip=None): + """Eliminate common subexpressions. + + Parameters + ---------- + fskip: Callable + The callback function that decides whether an expression should be + skipped. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that eliminates common subexpressions. + """ + return _transform.EliminateCommonSubexpr(fskip) + + +def PartialEvaluate(): + """Evaluate the static fragment of the code. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that performs partial evaluation on an expression. + """ + return _transform.PartialEvaluate() diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 57dc256ef6b7..e0014e919089 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -23,12 +23,8 @@ */ #include #include -#include #include -#include -#include -#include -#include +#include #include #include "utils.h" @@ -38,39 +34,7 @@ namespace relay { namespace backend { using TargetsMap = Map; - -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - * - */ -struct OptPassLevel { - static const std::unordered_map _data; - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - auto it = _data.find(key); - if (it == _data.end()) { - return -1; - } - return it->second; - } -}; - -const std::unordered_map OptPassLevel::_data = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 4}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} -}; +using namespace tvm::relay::transform; /*! * \brief Output of building module @@ -82,27 +46,6 @@ struct BuildOutput { std::unordered_map params; }; -/*! - * \brief Relay building config - * - */ -struct RelayBuildConfig { - int opt_level{2}; - int fallback_device{static_cast(kDLCPU)}; - std::unordered_set enabled_pass; - std::unordered_set disabled_pass; - OptPassLevel OPT_PASS_LEVEL; - inline bool pass_enabled(const std::string& pass_name) const { - if (disabled_pass.count(pass_name)) { - return false; - } - if (enabled_pass.count(pass_name)) { - return true; - } - return opt_level >= OPT_PASS_LEVEL[pass_name]; - } -}; - /*! * \brief GraphCodegen module wrapper * @@ -156,18 +99,6 @@ struct GraphCodegen { } }; -template -R CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - -template -Function CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - /*! * \brief Relay build module * @@ -203,28 +134,6 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); - } else if (name == "set_opt_level") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int level = args[0]; - this->SetOptLevel(level); - }); - } else if (name == "set_fallback_device") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int dev = args[0]; - this->SetFallBackDev(dev); - }); - } else if (name == "add_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->AddPass(pass_name); - }); - } else if (name == "disable_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->DisablePass(pass_name); - }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -246,30 +155,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::string& GetGraphJSON() { return ret_.graph_json; } - /*! - * \brief Add extra pass into build cfg - * - * \param pass_name name of pass - */ - void AddPass(const std::string& pass_name) { - cfg_.enabled_pass.insert(pass_name); - } - /*! - * \brief Disable a specific pass in cfg - * - * \param pass_name name of pass - */ - void DisablePass(const std::string& pass_name) { - cfg_.disabled_pass.insert(pass_name); - } - /*! - * \brief Set the Fallback device - * - * \param device name - */ - void SetFallBackDev(int dev) { - cfg_.fallback_device = dev; - } + /*! * \brief Get the Module object * @@ -315,15 +201,6 @@ class RelayBuildModule : public runtime::ModuleNode { params_[name] = data_in; } - /*! - * \brief Set the optimization level - * - * \param level - */ - void SetOptLevel(char level) { - cfg_.opt_level = level; - } - /*! * \brief type key * @@ -345,7 +222,7 @@ class RelayBuildModule : public runtime::ModuleNode { const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, cfg_, params_); + BuildRelay(func, params_); } protected: @@ -378,85 +255,81 @@ class RelayBuildModule : public runtime::ModuleNode { if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - auto e = CallPackedFunc("relay._make.Constant", kv.second); - bind_dict[arg] = e; + bind_dict[arg] = ConstantNode::make(kv.second); } - return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; } /*! - * \brief Optimize Relay function + * \brief Optimize a Relay module. * - * \param func Input function - * \param target target device - * \param cfg Relay build config - * \param params params dict - * \return relay::Function + * \param relay_module The input Relay module where optmization will be + * applied on. + * \param targets The device type to `Target` mapping. + * \param params The param name to value mapping. + * + * \return relay::Module The updated Relay module after optimization. */ - relay::Function Optimize(relay::Function func, - const TargetsMap& targets, - const RelayBuildConfig& cfg, - const std::unordered_map& params) { - if (params.size()) { - func = BindParamsByName(func, params); - } - if (cfg.pass_enabled("SimplifyInference")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.simplify_inference", func); - } - if (cfg.pass_enabled("EliminateCommonSubexpr")) { - auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } + relay::Module Optimize( + relay::Module relay_module, + const TargetsMap& targets, + const std::unordered_map& params) { + Array pass_seqs; + pass_seqs.push_back(transform::SimplifyInference()); + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; } } - *rv = false; - }); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); - } - if (cfg.pass_enabled("CombineParallelConv2D")) { - const int min_num_branches = 3; - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); - } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("FoldScaleAxis")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("CanonicalizeOps")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + pass_seqs.push_back(transform::AlterOpLayout()); } - if (cfg.pass_enabled("AlterOpLayout")) { - if (targets.size() == 1) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - for (const auto& kv : targets) { - With tctx(kv.second); - func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - } - } else { - LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" - << " execution yet."; + pass_seqs.push_back(transform::FoldConstant()); + + // Create a sequential pass and perform optimizations. + transform::Pass seq = transform::Sequential(pass_seqs); + if (targets.size() == 1) { + for (const auto& kv : targets) { + With tctx(kv.second); + relay_module = seq(relay_module); } + } else { + relay_module = seq(relay_module); } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + if (targets_.size() > 1) { + relay_module = + RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); } - return func; + + // Fuse the operations if it is needed. + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + + return relay_module; } /*! @@ -470,54 +343,58 @@ class RelayBuildModule : public runtime::ModuleNode { if (name == "gpu") return Target::Create("cuda"); return Target::Create(name); } + /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. * Meanwhile, a CPU device type and "llvm" pair will be added to the target * dictionary in this case. * - * \param targets dictionary - * \param cfg - * \return Map + * \param fallback_device The fallback device for heterogeneous execution. */ - TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, - const RelayBuildConfig& cfg) { - TargetsMap device_target = targets; + void UpdateHeterogeneousInputs(int fallback_device) { std::unordered_map tmp_map; - for (const auto& kv : targets) { + for (const auto& kv : targets_) { tmp_map[kv.first->value] = kv.second; } - if (tmp_map.count(cfg.fallback_device) == 0) { - device_target.Set( - cfg.fallback_device, - CreateDefaultTarget(cfg.fallback_device)); + if (tmp_map.count(fallback_device) == 0) { + targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); } - return device_target; } + /*! * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func - * \param cfg - * \param targets_map_ptr - * \return Function + * \param relay_module The input Relay module. + * \param fallback_device The fallback device for heterogeneous execution. + * + * \return updated_module The updated module after device annotation. */ - Function RunDeviceAnnotationPass(Function func, - const RelayBuildConfig& cfg, - TargetsMap* targets_map_ptr) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, - cfg.fallback_device); - auto device_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceInfo", func, nullptr); - if (device_map.size() == 0) { - auto annotation_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); - if (annotation_map.size() == 0) { - targets_map_ptr->Set( - 0, CreateDefaultTarget(cfg.fallback_device)); + relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, + int fallback_device) { + UpdateHeterogeneousInputs(fallback_device); + auto rewrite = transform::RewriteAnnotatedOps(fallback_device); + auto updated_module = rewrite(relay_module); + CHECK(updated_module.defined()); + + tvm::Map device_map; + for (const auto& it : updated_module->functions) { + device_map = relay::CollectDeviceInfo(it.second); + if (!device_map.empty()) break; + } + + if (device_map.empty()) { + tvm::Map annotation_map; + for (const auto& it : relay_module->functions) { + annotation_map = relay::CollectDeviceAnnotationOps(it.second); + if (!annotation_map.empty()) break; + } + // None op is annotated but they are fallen back to the default device. + if (annotation_map.empty()) { + targets_.Set(0, CreateDefaultTarget(fallback_device)); } else { + // All ops are annotated to the same device type. int64_t dev_type = -1; for (auto kv : annotation_map) { dev_type = kv.second->value; @@ -531,47 +408,42 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); + targets_.Set(0, CreateDefaultTarget(dev_type)); } } - return func; + return updated_module; } /*! * \brief Build relay function to runtime module * * \param func Relay Function - * \param cfg Relay build config * \param params parameters */ - void BuildRelay(Function func, - const RelayBuildConfig& cfg, - const std::unordered_map ¶ms) { - // convert - tvm_cfg_ = BuildConfig::Create(); - TargetsMap device_target; - if (targets_.size() > 1) { - device_target = UpdateHeterogeneousInputs(targets_, cfg); - } else { - device_target = targets_; - } - func = Optimize(func, targets_, cfg, params); - if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, cfg, &device_target); + void BuildRelay( + Function func, + const std::unordered_map& params) { + if (params.size()) { + func = BindParamsByName(func, params); } - // TODO(@jroesch): use the passes directly. - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + // Perform Module->Module optimizations. + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay_module = Optimize(relay_module, targets_, params); + CHECK(relay_module.defined()); + // Get the updated function. + func = relay_module->Lookup(relay_module->entry_func->name_hint); + + // Generate code for the updated function. graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Init(nullptr, targets_); graph_codegen_->Codegen(func); ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_); + ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, + BuildConfig::Current()); } protected: @@ -580,14 +452,10 @@ class RelayBuildModule : public runtime::ModuleNode { TargetsMap targets_; /*! \brief target host device */ tvm::Target target_host_; - /*! \brief frontend optimization configure */ - RelayBuildConfig cfg_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; - /*! \brief tvm building cfg */ - BuildConfig tvm_cfg_; }; runtime::Module RelayBuildCreate() { diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index f51c201d0b2a..d623393049a6 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -338,17 +339,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // Limiations: // 1. the altered op should have the same number of arguments as the previous one // 2. do not support nested tuple arguments -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body([](TVMArgs args, TVMRetValue *ret) { +Expr AlterOpLayout(const Expr& expr) { TransformMemorizer transformMemorizer(make_node()); auto fcontext = [&](const Call& call) -> NodeRef{ return transformMemorizer; }; - *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); -}); + return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body_typed(AlterOpLayout); } // namespace alter_op_layout +namespace transform { + +Pass AlterOpLayout() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.AlterOpLayout") +.set_body_typed(AlterOpLayout); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 9a4602750195..ff9e2304a3bc 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -63,5 +64,21 @@ Expr CanonicalizeOps(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") .set_body_typed(CanonicalizeOps); +namespace transform { + +Pass CanonicalizeOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CanonicalizeOps") +.set_body_typed(CanonicalizeOps); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 7e76322d5a2a..c95c1ddf8e16 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include "./expr_subst.h" @@ -357,5 +358,21 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body_typed(CombineParallelConv2D); +namespace transform { + +Pass CombineParallelConv2D(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelConv2D") +.set_body_typed(CombineParallelConv2D); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index dd1ed6240cab..be6774564806 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -158,9 +158,12 @@ Pass DeadCodeElimination() { [=](Function f, Module m, PassContext pc) { return Downcast(DeadCodeElimination(f)); }; - return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); + return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } +TVM_REGISTER_API("relay._transform.DeadCodeElimination") +.set_body_typed(DeadCodeElimination); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index e2d07619cb0f..02d6d9e1fefb 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -564,11 +565,14 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, Module m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", + {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") +.set_body_typed(RewriteAnnotatedOps); + } // namespace transform } // namespace relay } // namespace tvm - diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index f8432f671855..883681adcaf4 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -29,6 +29,7 @@ */ #include #include +#include #include #include "./pattern_util.h" @@ -87,5 +88,21 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") .set_body_typed(EliminateCommonSubexpr); +namespace transform { + +Pass EliminateCommonSubexpr(PackedFunc fskip) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") +.set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 286392ab5d3f..815407038b08 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -220,11 +221,14 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(FoldConstant(f)); + return Downcast(FoldConstant(f)); }; - return CreateFunctionPass(pass_func, 1, "fold_constant", {}); + return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } +TVM_REGISTER_API("relay._transform.FoldConstant") +.set_body_typed(FoldConstant); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index c738e3e3b731..53089807ace5 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" #include "pass_util.h" @@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); -Expr ForwardFoldScaleAxis(Expr data) { +Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ auto it = message.find(call.get()); @@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d") RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); -Expr BackwardFoldScaleAxis(Expr data) { +Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } @@ -950,5 +951,42 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") .set_body_typed(BackwardFoldScaleAxis); } // namespace fold_scale_axis + +namespace transform { + +Pass ForwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; + return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass BackwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass FoldScaleAxis() { + // FoldScaleAxis pass contains the following three passes. Therefore, we can + // register it as a sequential pass. + Pass pass = Sequential( + {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); + return pass; +} + +TVM_REGISTER_API("relay._transform.FoldScaleAxis") +.set_body_typed(FoldScaleAxis); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 2a3aa1612418..8ad61270e33a 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -220,7 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {}); } Pass ForwardRewrite(const FForwardRewrite& rewrite_func, @@ -233,7 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {}); } } // namespace transform diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9277689075c2..9f940e54953b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "./pattern_util.h" #include "../../common/arena.h" @@ -973,9 +974,13 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "fuse_ops", {}); + return CreateFunctionPass(pass_func, 1, "FuseOps", + {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.FuseOps") +.set_body_typed(FuseOps); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 3f42c6fce4b2..71ba7cd11cd5 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -797,9 +797,7 @@ Expr PartialEval(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.partial_evaluate") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PartialEval(args[0]); - }); +.set_body_typed(PartialEval); namespace transform { @@ -808,9 +806,12 @@ Pass PartialEval() { [=](Function f, Module m, PassContext pc) { return Downcast(PartialEval(f)); }; - return CreateFunctionPass(pass_func, 1, "partial_eval", {}); + return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); } +TVM_REGISTER_API("relay._transform.PartialEvaluate") +.set_body_typed(PartialEval); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a9c671aa163a..13e908d28f7a 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,42 +37,46 @@ namespace transform { using tvm::IRPrinter; -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - */ -class OptPassLevel { - public: - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - const auto data = CreateMap(); - auto it = data.find(key); - if (it == data.end()) { - return -1; - } - return it->second; +namespace { + +// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be +// handled because we need to register the pass for Python invocation anyway. +Pass GetPass(const std::string& pass_name) { + if (pass_name == "InferType") { + return InferType(); + } else if (pass_name == "AlterOpLayout") { + return AlterOpLayout(); + } else if (pass_name == "CanonicalizeOps") { + return CanonicalizeOps(); + } else if (pass_name == "CombineParallelConv2d") { + return CombineParallelConv2D(); + } else if (pass_name == "DeadCodeElimination") { + return DeadCodeElimination(); + } else if (pass_name == "EliminateCommonSubexpr") { + return DeadCodeElimination(); + } else if (pass_name == "FoldConstant") { + return FoldConstant(); + } else if (pass_name == "BackwardFoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "ForwardFoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "FoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "PartialEvaluate") { + return SimplifyInference(); + } else if (pass_name == "SimplifyInference") { + return SimplifyInference(); + } else if (pass_name == "ToANormalForm") { + return ToANormalForm(); + } else if (pass_name == "ToGraphNormalForm") { + return ToGraphNormalForm(); + } else { + LOG(FATAL) << pass_name << " has not been registered yet." << "\n"; + return Pass(nullptr); } +} - private: - static const std::unordered_map CreateMap() { - const std::unordered_map m = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} - }; - return m; - } -}; +} // namespace struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -246,12 +250,6 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; - /*! - * \brief A helper struct to get the optimization pass name to opt level - * mapping. - */ - OptPassLevel opt_pass_level; - /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -300,7 +298,7 @@ class SequentialNode : public PassNode { const Array& disabled) const; std::unordered_set RequiredPasses( - const Array& disabled) const; + const Array& required) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -338,14 +336,25 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) Check and handle the required passes. Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx); + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + auto pass = GetPass(name->value); + updated_mod = pass(updated_mod, pass_ctx); + } + + updated_mod = pass_func(updated_mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } @@ -365,12 +374,26 @@ Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); CHECK(mod.defined()); - Module new_mod = ModuleNode::make({}, mod->type_definitions); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + auto pass = GetPass(name->value); + updated_mod = pass(updated_mod, pass_ctx); + } + + Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. for (const auto& it : mod->functions) { - auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); + auto updated_func = SkipFunction(it.second) + ? it.second + : pass_func(it.second, updated_mod, pass_ctx); new_mod->Add(it.first, updated_func); } @@ -418,7 +441,7 @@ std::unordered_set SequentialNode::DisabledPasses( std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Disabled pass name must be string."; ret.emplace(str->value); } return ret; @@ -429,7 +452,7 @@ std::unordered_set SequentialNode::RequiredPasses( std::unordered_set ret; for (const auto& it : required) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Required pass name must be string."; ret.emplace(str->value); } return ret; @@ -439,7 +462,7 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { PassContext ctx = PassContext::Current(); auto required = RequiredPasses(ctx->required_pass); - auto disabled = DisabledPasses(ctx->required_pass); + auto disabled = DisabledPasses(ctx->disabled_pass); if (disabled.count(pass_name)) { return false; @@ -448,29 +471,27 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { if (required.count(pass_name)) { return true; } - return ctx->opt_level >= opt_pass_level[pass_name]; + + const Pass pass = GetPass(pass_name); + PassInfo info = pass->Info(); + return ctx->opt_level >= info->opt_level; } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. +// ordering problem needs to be handled in the future. Module SequentialNode::operator()(const Module& module, const PassContext& pass_ctx) const { - int opt_level = pass_ctx->opt_level; - auto disabled = DisabledPasses(pass_ctx->disabled_pass); Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; + PassInfo info = pass->Info(); const auto& pass_name = info->name; - const auto& pass_opt_level = info->opt_level; - // Skip the pass if its optimization level is higher that the one of in the - // pass context or if this pass is disabled. - if (pass_opt_level > opt_level || disabled.count(pass_name)) { - continue; + // Execute the pass if it is enabled. + if (PassEnabled(pass_name)) { + mod = pass(mod, pass_ctx); } - const auto* pn = pass.operator->(); - mod = (*pn)(mod, pass_ctx); } return mod; } @@ -525,15 +546,17 @@ TVM_REGISTER_API("relay._transform.CreateModulePass") TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Pass()(args[1]); + Pass pass = args[0]; + Module mod = args[1]; + *ret = pass(mod); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ModulePassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Module pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(FunctionPassNode); @@ -544,9 +567,9 @@ TVM_REGISTER_API("relay._transform.CreateFunctionPass") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Function pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(SequentialNode); @@ -564,14 +587,13 @@ TVM_REGISTER_API("relay._transform.Sequential") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SequentialNode* node, tvm::IRPrinter* p) { - const PassInfoNode* seq_pn = node->Info().operator->(); - p->stream << "Run Sequential pass: " << seq_pn->name - << " at the optimization level " << seq_pn->opt_level << ". "; + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name + << " at the optimization level " << info->opt_level << ". "; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { - const PassNode* pn = it.operator->(); - const PassInfoNode* pass_info_node = pn->Info().operator->(); - p->stream << pass_info_node->name << " "; + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; } p->stream << "]"; }); diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 8dab0c370853..6d6b24abec20 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "./pattern_util.h" namespace tvm { @@ -105,5 +106,21 @@ Expr SimplifyInference(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.simplify_inference") .set_body_typed(SimplifyInference); +namespace transform { + +Pass SimplifyInference() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(SimplifyInference(f)); + }; + return CreateFunctionPass(pass_func, 0, "SimplifyInference", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.SimplifyInference") +.set_body_typed(SimplifyInference); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index f9d47f78a6d2..324eddd21c5c 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -340,9 +340,12 @@ Pass ToANormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToANormalForm(f, m)); }; - return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); } +TVM_REGISTER_API("relay._transform.ToANormalForm") +.set_body_typed(ToANormalForm); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 50ebb702e4b2..9c166f98c1a5 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -86,9 +86,12 @@ Pass ToGraphNormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToGraphNormalForm(f)); }; - return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } +TVM_REGISTER_API("relay._transform.ToGraphNormalForm") +.set_body_typed(ToGraphNormalForm); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 482cef3b2c2d..3fde3c7e7b36 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -43,6 +43,7 @@ #include #include #include +#include #include "./pass_util.h" #include "type_solver.h" #include "../ir/type_functor.h" @@ -807,5 +808,23 @@ TVM_REGISTER_API("relay._ir_pass.infer_type") .set_body_typed([](const Expr& expr, const Module& mod_ref) { return InferType(expr, mod_ref); }); + +namespace transform { + +Pass InferType() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(InferType(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "InferType", {}); +} + +TVM_REGISTER_API("relay._transform.InferType") +.set_body_typed([]() { + return InferType(); +}); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc new file mode 100644 index 000000000000..b61a5cc0daad --- /dev/null +++ b/tests/cpp/relay_transform_sequential.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TVM_REGISTER_GLOBAL("schedule") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); + }); + +TEST(Relay, Sequential) { + using namespace tvm; + auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32)); + auto c_data = + tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + // Create a function for optimization. + auto c = relay::ConstantNode::make(c_data); + auto a = relay::VarNode::make("a", tensor_type); + auto x = relay::VarNode::make("x", tensor_type); + auto add_op = relay::Op::Get("add"); + auto y = relay::CallNode::make(add_op, {c, c}); + y = relay::CallNode::make(add_op, {x, y}); + auto z = relay::CallNode::make(add_op, {y, c}); + auto z1 = relay::CallNode::make(add_op, {y, c}); + auto z2 = relay::CallNode::make(add_op, {z, z1}); + // Let expression and varaible a should be dead-code eliminated. + auto z3 = relay::LetNode::make(a, c, z2); + relay::Function func = + relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {}); + + // Get schedule + auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto sch = tvm::runtime::Registry::Get("schedule"); + if (!reg || !sch) { + LOG(FATAL) << "Register/schedule is not defined."; + } + + (*reg)("add", "FTVMSchedule", *sch, 10); + + // Run sequential passes. + tvm::Array pass_seqs{ + relay::transform::InferType(), + relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), + relay::transform::AlterOpLayout() + }; + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + auto mod = relay::ModuleNode::FromExpr(func); + auto pass_ctx = relay::transform::PassContext::Create(); + pass_ctx->opt_level = 3; + pass_ctx->fallback_device = 1; + { + tvm::With ctx_scope(pass_ctx); + tvm::With tctx(tvm::Target::Create("llvm")); + mod = seq(mod); + } + + CHECK(mod.defined()); + auto entry_func = mod->entry_func; + CHECK(entry_func.defined()); + relay::Function f = mod->Lookup(entry_func->name_hint); + CHECK(f.defined()); + + // Expected function + auto c1 = relay::ConstantNode::make(c_data); + auto x1 = relay::VarNode::make("x", tensor_type); + auto y1 = relay::CallNode::make(add_op, {c1, c1}); + y1 = relay::CallNode::make(add_op, {x1, y1}); + auto zz = relay::CallNode::make(add_op, {y1, c1}); + zz = relay::CallNode::make(add_op, {zz, zz}); + relay::Function expected_func = + relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); + + // Infer type for the expected function. + auto expected = relay::InferType(expected_func, relay::Module(nullptr)); + CHECK(relay::AlphaEqual(f, expected)); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 2703e5ce1679..7fdef3fa8b9c 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -327,7 +327,8 @@ def test_no_pass(): def test_only_module_pass(): passes = [module_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["mod_transform"]): + ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, sub) @@ -341,7 +342,8 @@ def test_only_function_pass(): # Check the subtract function. passes = [function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["func_transform"]): + ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -355,7 +357,9 @@ def test_multiple_passes(): mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + required = ["mod_transform", "func_transform"] + with relay.build_config(required_pass=required): + ret_mod = sequential(mod) # Check the abs function is added. abs_var, abs_func = get_var_func() @@ -400,7 +404,48 @@ def test_multiple_passes(): test_multiple_passes() +def test_sequential_with_scoping(): + shape = (1, 2, 3) + c_data = np.array(shape).astype("float32") + tp = relay.TensorType(shape, "float32") + def before(): + c = relay.const(c_data) + x = relay.var("x", tp) + y = relay.add(c, c) + y = relay.multiply(y, relay.const(2, "float32")) + y = relay.add(x, y) + z = relay.add(y, c) + z1 = relay.add(y, c) + z2 = relay.add(z, z1) + return relay.Function([x], z2) + + def expected(): + x = relay.var("x", tp) + c_folded = (c_data + c_data) * 2 + y = relay.add(x, relay.const(c_folded)) + z = relay.add(y, relay.const(c_data)) + z1 = relay.add(z, z) + return relay.Function([x], z1) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.EliminateCommonSubexpr(), + relay.transform.AlterOpLayout() + ]) + + mod = relay.Module({"main": before()}) + with relay.build_config(opt_level=3): + with tvm.target.create("llvm"): + mod = seq(mod) + + zz = mod["main"] + zexpected = ir_pass.infer_type(expected()) + assert relay.ir_pass.alpha_equal(zz, zexpected) + + if __name__ == "__main__": test_module_pass() test_function_pass() test_sequential_pass() + test_sequential_with_scoping() From 439c2f89a6546c071b7afbc847066bc2f572861d Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 4 Jun 2019 18:42:27 +0300 Subject: [PATCH 070/176] [ARITH] Bugfix: int bound analysis for mod (#3288) --- src/arithmetic/const_int_bound.cc | 2 +- tests/python/unittest/test_arith_rewrite_simplify.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 72b85084d59d..ed8faba3509b 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -190,7 +190,7 @@ class ConstIntBoundAnalyzer::Impl : std::min(a.max_value, b_max_cap)); } else { return MakeBound(std::max(a.min_value, -b_max_cap), - std::min(a.max_value, b_max_cap)); + std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); } } else { CHECK(!b.is_const(0)) << "mod by zero"; diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index ee113e101cce..596e54d338b5 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -564,6 +564,7 @@ def test_cmp_simplify(): ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool")) ck.verify(y*y >= 0, tvm.const(1, "bool")) ck.verify(x*6 <= -3, tvm.const(0, "bool")) + ck.verify((y - 1) % 3 == 0, (y + (-1)) % 3 == 0) def test_logical_simplify(): From 23ce1f2db0d6da21522dbab98651fb39926a29d4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 4 Jun 2019 08:42:47 -0700 Subject: [PATCH 071/176] Bump ONNX version (#3286) --- docker/install/ubuntu_install_onnx.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index ec5b7f3b4964..a073389472b2 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,7 @@ set -u set -o pipefail # fix to certain version for now -pip3 install onnx>=1.1.0 +pip3 install onnx>=1.4.1 pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl pip3 install torchvision From c861aa3401b8652afb9b729fb988e4e53175ecab Mon Sep 17 00:00:00 2001 From: Hua Date: Tue, 4 Jun 2019 09:47:29 -0700 Subject: [PATCH 072/176] [Bugfix] [VTA] VTA DRAM Have A Logic Issue May Cause GEMM Output Wrong. (#3278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Bugfix] [VTA] VTA DRAM Have A Logic Issue May Cause GEMM Output Wrong. Symptom: after change “LOG_BLOCK_IN” and “LOG_BLOCK_OUT” from vta_config.json into 7, run vta "Simple Matrix Multiply" in "simulator", the vta calculate result for GEMM is wrong. Sometime VTA crash with error “Check failed: phy_addr != 0 (0 vs. 0) : trying to get address that is nullptr” Analysis: Simulator hardcode kPageSize into 1<<12 and physical address calculate based on this size, when doing “insn->dram_base” calculation , because GetElemBytes(dst_memory_type) larger than page size, different physcial address may get same dram_base, than caused logic issue and finally trigger GEMM out put is wrong. Solution: add logic to check if PAGE SIZE larger then "GetElemBytes" return value. * address review comments. --- vta/include/vta/driver.h | 4 ++++ vta/src/runtime.cc | 33 +++++++++++++++++++++++++-------- vta/src/sim/sim_driver.cc | 4 ++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/vta/include/vta/driver.h b/vta/include/vta/driver.h index ed041853b117..d583051dc194 100644 --- a/vta/include/vta/driver.h +++ b/vta/include/vta/driver.h @@ -45,6 +45,10 @@ extern "C" { #define VTA_MAX_XFER (1<<22) #endif +/*! PAGE SIZE */ +#define VTA_PAGE_BITS 12 +#define VTA_PAGE_BYTES (1 << VTA_PAGE_BITS) + /*! \brief Device resource context */ typedef void * VTADeviceHandle; diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 7af0de1a8f8b..79a407fe521e 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -913,16 +913,33 @@ class CommandQueue { } uint32_t GetElemBytes(uint32_t memory_id) { + uint32_t elem_bytes = 0; switch (memory_id) { - case VTA_MEM_ID_UOP: return VTA_UOP_ELEM_BYTES; - case VTA_MEM_ID_INP: return VTA_INP_ELEM_BYTES; - case VTA_MEM_ID_WGT: return VTA_WGT_ELEM_BYTES; - case VTA_MEM_ID_ACC: return VTA_ACC_ELEM_BYTES; - case VTA_MEM_ID_OUT: return VTA_INP_ELEM_BYTES; - default: break; + case VTA_MEM_ID_UOP: + elem_bytes = VTA_UOP_ELEM_BYTES; + break; + case VTA_MEM_ID_INP: + elem_bytes = VTA_INP_ELEM_BYTES; + break; + case VTA_MEM_ID_WGT: + elem_bytes = VTA_WGT_ELEM_BYTES; + break; + case VTA_MEM_ID_ACC: + elem_bytes = VTA_ACC_ELEM_BYTES; + break; + case VTA_MEM_ID_OUT: + elem_bytes = VTA_INP_ELEM_BYTES; + break; + default: + LOG(FATAL) << "Memory id not recognized:" << memory_id; + break; } - LOG(FATAL) << "Memory id not recognized:" << memory_id; - return 0; + /* + * elements size should not larger than VTA_PAGE_BYTES. + * + */ + CHECK_GE(VTA_PAGE_BYTES, elem_bytes); + return elem_bytes; } void LoadBuffer2D(void* src_dram_addr, diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc index 803a54d6b96a..5f9f6b637599 100644 --- a/vta/src/sim/sim_driver.cc +++ b/vta/src/sim/sim_driver.cc @@ -196,9 +196,9 @@ class DRAM { private: // The bits in page table - static constexpr vta_phy_addr_t kPageBits = 12; + static constexpr vta_phy_addr_t kPageBits = VTA_PAGE_BITS; // page size, also the maximum allocable size 16 K - static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits; + static constexpr vta_phy_addr_t kPageSize = VTA_PAGE_BYTES; /*! \brief A page in the DRAM */ struct Page { /*! \brief Data Type */ From 21daf8eecaba5f1e8000430025d0bbfd400dbab7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 4 Jun 2019 13:28:36 -0700 Subject: [PATCH 073/176] [Relay][Docs] Add parser dependency install instructions. (#3277) * [Relay][Docs] Add parser dependency install instructions. See https://discuss.tvm.ai/t/trouble-enabling-antlr/2783. * Add a word. * Update since the parser will now be committed to the repo. * revert b/c adding the parser doesn't fix this --- docs/install/from_source.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 3a769dee2dce..1ea8f3478341 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -192,6 +192,12 @@ Python dependencies .. code:: bash pip install --user tornado psutil xgboost + + * If you want to parse Relay text format progams, you must use Python 3 and run the following + + .. code:: bash + + pip install --user mypy orderedset antlr4-python3-runtime Install Contrib Libraries From 4e90aac9470db23a3b4df666505d8ee27428a284 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 4 Jun 2019 16:29:56 -0700 Subject: [PATCH 074/176] [Relay/TOPI][Op] Add TopK operator (#3256) * init impl for topk * Fix cpu for topk * init cuda impl for topk * Add cuda for topk * fix * Add doc * update doc * lint * lint * lint * x * fix warning * [Relay] Add TopK in tf converter * Add frontend converter * fix --- docs/api/python/topi.rst | 4 + docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/algorithm.h | 25 ++ python/tvm/relay/frontend/mxnet.py | 16 ++ python/tvm/relay/frontend/tensorflow.py | 15 ++ python/tvm/relay/op/_algorithm.py | 30 ++- python/tvm/relay/op/algorithm.py | 44 ++- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/relay/op/transform.py | 6 +- src/codegen/build_module.cc | 2 +- src/contrib/sort/sort.cc | 251 +++++++++++++++--- .../op/algorithm/{sort.cc => argsort.cc} | 8 +- src/relay/op/algorithm/topk.cc | 101 +++++++ tests/python/frontend/mxnet/test_forward.py | 41 +++ .../frontend/tensorflow/test_forward.py | 19 ++ tests/python/relay/test_op_level6.py | 62 ++++- topi/python/topi/cuda/__init__.py | 1 + topi/python/topi/cuda/nms.py | 2 +- topi/python/topi/cuda/sort.py | 241 ++++++++++++----- topi/python/topi/generic/sort.py | 17 ++ topi/python/topi/sort.py | 83 ++++-- topi/python/topi/transform.py | 2 + topi/python/topi/vision/nms.py | 2 +- topi/tests/python/test_topi_sort.py | 76 +++++- 24 files changed, 904 insertions(+), 148 deletions(-) rename src/relay/op/algorithm/{sort.cc => argsort.cc} (94%) create mode 100644 src/relay/op/algorithm/topk.cc diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 0b217d4fe3af..ade0f1a5b390 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -99,6 +99,8 @@ List of operators topi.shape topi.layout_transform topi.image.resize + topi.argsort + topi.topk List of schedules @@ -163,6 +165,8 @@ topi .. autofunction:: topi.tile .. autofunction:: topi.shape .. autofunction:: topi.layout_transform +.. autofunction:: topi.argsort +.. autofunction:: topi.topk topi.nn ~~~~~~~ diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 836f8f30bfa8..28ee99e77981 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -172,6 +172,7 @@ This level enables additional math and transform operators. :nosignatures: tvm.relay.argsort + tvm.relay.topk **Level 10: Temporary Operators** @@ -309,6 +310,7 @@ Level 5 Definitions Level 6 Definitions ------------------- .. autofunction:: tvm.relay.argsort +.. autofunction:: tvm.relay.topk Level 10 Definitions diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 20f135c11bba..f5ba6999347f 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode { } }; +struct TopKAttrs : public tvm::AttrsNode { + int k; + int axis; + bool is_ascend; + std::string ret_type; + DataType dtype; + + TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { + TVM_ATTR_FIELD(k).set_default(1) + .describe("Number of top elements to select"); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Axis along which to sort the input tensor."); + TVM_ATTR_FIELD(ret_type).set_default("both") + .describe("The return type [both, values, indices]." + "both - return both top k data and indices." + "values - return top k data only." + "indices - return top k indices only."); + TVM_ATTR_FIELD(is_ascend).set_default(false) + .describe("Whether to sort in ascending or descending order." + "By default, sort in descending order"); + TVM_ATTR_FIELD(dtype).set_default(NullValue()) + .describe("Data type of the output indices."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0bc7923648ff..0975a33450c8 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs): return _op.argsort(inputs[0], **new_attrs) +def _mx_topk(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["k"] = attrs.get_int("k", 1) + new_attrs["axis"] = attrs.get_int("axis", -1) + new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True) + ret_type = attrs.get_str("ret_typ", "indices") + if ret_type == "mask": + raise tvm.error.OpAttributeUnimplemented( + "Attribute ret_type=mask is not supported in topk operator") + new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type + new_attrs["dtype"] = attrs.get_str("dtype", "float32") + return _op.topk(inputs[0], **new_attrs) + + def _mx_rnn_param_concat(inputs, _): # We don't need to concatenate RNN params because we will unravel the RNN op return [inputs] @@ -914,6 +929,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, "argsort" : _mx_argsort, + "topk" : _mx_topk, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, "LinearRegressionOutput" : _mx_linear_regression_output, diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7fe82ea7eac1..307fb20693f4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1082,6 +1082,20 @@ def _impl(inputs, attr, params): return _get_relay_op('log')(add_out) return _impl +def _topk(): + def _impl(inputs, attr, params): + k = int(params.pop(inputs.pop(1).name_hint).asnumpy()) + if k < 1: + raise tvm.error.OpAttributeInvalid( + 'Attribute k must be positive in operator TopKV2') + if attr['sorted'] is False: + raise tvm.error.OpAttributeUnimplemented( + 'Attribute sorted=False is not supported in operator TopKV2') + return AttrCvt(op_name='topk', + ignores=['sorted'], + extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -1271,6 +1285,7 @@ def _impl(inputs, attr, params): 'Sum' : _sum(), 'Tanh' : AttrCvt('tanh'), 'Tile' : _tile(), + 'TopKV2' : _topk(), 'Transpose' : _transpose(), 'Unpack' : _unpack(), diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 57e716534ee5..09746be13e30 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target): """Compute definition of argsort""" axis = get_const_int(attrs.axis) is_ascend = bool(get_const_int(attrs.is_ascend)) - dtype = str(attrs.dtype) - return [ - topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \ - dtype=dtype, flag=False) - ] + dtype = attrs.dtype + return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)] register_pattern("argsort", OpPattern.OPAQUE) + + +@register_schedule("topk") +def schedule_topk(_, outs, target): + """Schedule definition of argsort""" + with target: + return topi.generic.schedule_topk(outs) + + +@register_compute("topk") +def compute_topk(attrs, inputs, _, target): + """Compute definition of argsort""" + k = get_const_int(attrs.k) + axis = get_const_int(attrs.axis) + ret_type = attrs.ret_type + is_ascend = bool(get_const_int(attrs.is_ascend)) + dtype = attrs.dtype + out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype) + out = out if isinstance(out, list) else [out] + return out + + +register_pattern("topk", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 6451eb41aeb9..6f875919df4c 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,8 +17,9 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make +from ..expr import TupleWrapper -def argsort(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): Whether to sort in ascending or descending order. dtype : string, optional - DType of the output indices. + The data type of the output indices. Returns ------- @@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): Tensor with same shape as data. """ return _make.argsort(data, axis, is_ascend, dtype) + + +def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): + """Get the top k elements in an input tensor along the given axis. + + ret_type specifies the return type, can be one of ("both", "values", "indices"). + + Parameters + ---------- + data : relay.Expr + The input data tensor. + + k : int, optional + Number of top elements to select. Return all elements if k < 1. + + axis : int, optional + Axis long which to sort the input tensor. + + ret_type: str, optional + The return type [both, values, indices]. + "both": return both top k data and indices. + "values": return top k data only. + "indices": return top k indices only. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. + + dtype : string, optional + The data type of the indices output. + + Returns + ------- + out : relay.Expr or List[relay.Expr] + The computed result. + """ + out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) + if ret_type == "both": + return TupleWrapper(out, 2) + return out diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b772c43e11cd..b4ebffb355d0 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -401,7 +401,7 @@ def upsampling(data, with data of shape (n, c, h, w) out will have a shape (n, c, h*scale, w*scale) - method indicates the algorithm to be used while calculating ghe out value + method indicates the algorithm to be used while calculating the out value and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") Parameters diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9c76b7e569dc..02fd4924b804 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"): the flattened input array is used. mode : str, optional - Specifies how out-of-bound indices will behave. - clip - clip to the range (default) - wrap - wrap around the indices + Specifies how out-of-bound indices will behave [clip, wrap]. + clip: clip to the range (default). + wrap: wrap around the indices. Returns ------- diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 834b4eea7e3f..0a488f38457b 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name, t->device_type = kDLGPU; t->keys_array.push_back(ir::StringImm::make("cuda")); t->keys_array.push_back(ir::StringImm::make("gpu")); - t->max_num_threads = 512; + t->max_num_threads = 1024; t->thread_warp_size = 32; } else if (target_name == "rocm" || target_name == "opencl") { // For now assume rocm schedule for opencl diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index cf25e89b9109..87691f254c5c 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -34,14 +34,14 @@ namespace contrib { using namespace runtime; template -bool CompareAscend(const std::pair& lhs, - const std::pair& rhs) { +bool CompareAscend(const std::pair& lhs, + const std::pair& rhs) { return lhs.second < rhs.second; } template -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { +bool CompareDescend(const std::pair& lhs, + const std::pair& rhs) { return lhs.second > rhs.second; } @@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") } }); +template +void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { + auto data_ptr = static_cast(input->data); + auto out_ptr = static_cast(output->data); + std::vector > sorter; + + int axis_mul_before = 1; + int axis_mul_after = 1; + for (int i = 0; i < input->ndim; ++i) { + if (i < axis) { + axis_mul_before *= input->shape[i]; + } else if (i > axis) { + axis_mul_after *= input->shape[i]; + } + } + + for (int i = 0 ; i < axis_mul_before; ++i) { + for (int j = 0 ; j < axis_mul_after; ++j) { + sorter.clear(); + int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; + for (int64_t k = 0; k < input->shape[axis]; ++k) { + int64_t full_idx = base_idx + k * axis_mul_after; + sorter.emplace_back(std::make_pair(k, data_ptr[full_idx])); + } + if (is_ascend) { + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + } else { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + } + for (int64_t k = 0; k < input->shape[axis]; ++k) { + out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].first); + } + } + } +} // Argsort implemented C library sort. // Return indices of sorted tensor. @@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") DLTensor *output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; - - auto dtype = input->dtype; - auto data_ptr = static_cast(input->data); - std::vector> sorter; - int64_t axis_mul_before = 1; - int64_t axis_mul_after = 1; - if (axis < 0) { axis = input->ndim + axis; } - - // Currently only supports input dtype to be float32. - CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float32."; - CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " << input->ndim; + + auto data_dtype = TVMType2String(input->dtype); + auto out_dtype = TVMType2String(output->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } +}); +template +void topk(DLTensor* input, + DLTensor* out_values, + DLTensor* out_indices, + int k, + int axis, + bool is_ascend) { + DataType* data_ptr = static_cast(input->data); + DataType* values_ptr = (out_values == nullptr) ? nullptr : + static_cast(out_values->data); + IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : + static_cast(out_indices->data); + std::vector > sorter; + + int axis_mul_before = 1; + int axis_mul_after = 1; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { axis_mul_before *= input->shape[i]; @@ -150,27 +244,124 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") axis_mul_after *= input->shape[i]; } } + if (k < 1) { + k = input->shape[axis]; + } - int32_t current_sort_num = input->shape[axis]; - for (int64_t i = 0 ; i < axis_mul_before; ++i) { - for (int64_t j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0 ; i < axis_mul_before; ++i) { + for (int j = 0 ; j < axis_mul_after; ++j) { sorter.clear(); - int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; - for (int64_t k = 0; k < current_sort_num; ++k) { - int64_t full_idx = base_idx + k * axis_mul_after; - sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); + int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; + int64_t dst_base_idx = i * k * axis_mul_after + j; + for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { + int64_t full_idx = src_base_idx + kk * axis_mul_after; + sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); } if (is_ascend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); } - for (int32_t k = 0; k < input->shape[axis]; ++k) { - *(static_cast(output->data) + base_idx + k * axis_mul_after) - = k < static_cast(sorter.size()) ? sorter[k].first : k; + int64_t cnt = k > 0 ? k : input->shape[axis]; + for (int64_t kk = 0; kk < cnt; ++kk) { + if (indices_ptr != nullptr) { + indices_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(sorter[kk].first); + } + if (values_ptr != nullptr) { + values_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(sorter[kk].second); + } } } } +} + +// Argsort implemented C library sort. +// Return indices of sorted tensor. +// By default, the last axis will be used to sort. +// sort_num specify the number of elements to be sorted. +// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) +// and sort axis is dk. sort_num should have dimension of +// (d1, d2, ..., d(k-1), d(k+1), ..., dn). +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* values_out = nullptr; + DLTensor* indices_out = nullptr; + int k = args[args.num_args - 4]; + int axis = args[args.num_args - 3]; + std::string ret_type = args[args.num_args - 2]; + bool is_ascend = args[args.num_args - 1]; + if (ret_type == "both") { + values_out = args[1]; + indices_out = args[2]; + } else if (ret_type == "values") { + values_out = args[1]; + } else if (ret_type == "indices") { + indices_out = args[1]; + } else { + LOG(FATAL) << "Unsupported ret type: " << ret_type; + } + if (axis < 0) { + axis = input->ndim + axis; + } + CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; + + auto data_dtype = TVMType2String(input->dtype); + auto out_dtype = (indices_out == nullptr) ? "int64" : TVMType2String(indices_out->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } }); } // namespace contrib diff --git a/src/relay/op/algorithm/sort.cc b/src/relay/op/algorithm/argsort.cc similarity index 94% rename from src/relay/op/algorithm/sort.cc rename to src/relay/op/algorithm/argsort.cc index 5777b79699b1..31aa88808a23 100644 --- a/src/relay/op/algorithm/sort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2018 by Contributors - * \file nms.cc - * \brief Non-maximum suppression operators + * Copyright (c) 2019 by Contributors + * \file argsort.cc + * \brief Argsort operators */ #include #include @@ -44,7 +44,6 @@ bool ArgsortRel(const Array& types, << types[0]; return false; } - CHECK_EQ(param->dtype, Float(32)); reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype)); return true; } @@ -74,5 +73,6 @@ input array along the given axis. .add_argument("data", "Tensor", "Input data.") .set_support_level(6) .add_type_rel("Argsort", ArgsortRel); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc new file mode 100644 index 000000000000..c88e2c3ea007 --- /dev/null +++ b/src/relay/op/algorithm/topk.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file topk.cc + * \brief TopK operators + */ +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(TopKAttrs); + +bool TopKRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + const TopKAttrs* param = attrs.as(); + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data); + int ndim = data->shape.size(); + int axis = param->axis; + if (axis < 0) { + axis += ndim; + } + CHECK(axis >= 0 && axis < ndim); + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i != axis || param->k < 1) { + out_shape.push_back(data->shape[i]); + } else { + out_shape.push_back(param->k); + } + } + auto values_ty = TensorTypeNode::make(out_shape, data->dtype); + auto indices_ty = TensorTypeNode::make(out_shape, param->dtype); + if (param->ret_type == "both") { + reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty})); + } else if (param->ret_type == "values") { + reporter->Assign(types[1], values_ty); + } else if (param->ret_type == "indices") { + reporter->Assign(types[1], indices_ty); + } else { + LOG(FATAL) << "Unsupported ret type: " << param->ret_type; + } + return true; +} + +Expr MakeTopK(Expr data, + int k, + int axis, + std::string ret_type, + bool is_ascend, + DataType dtype) { + auto attrs = make_node(); + attrs->k = k; + attrs->axis = axis; + attrs->ret_type = ret_type; + attrs->is_ascend = is_ascend; + attrs->dtype = dtype; + static const Op& op = Op::Get("topk"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op._make.topk") +.set_body_typed(MakeTopK); + +RELAY_REGISTER_OP("topk") +.describe(R"doc(Get the top k elements in an input tensor along the given axis. +)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.TopKAttrs") +.add_argument("data", "Tensor", "Input data.") +.set_support_level(6) +.add_type_rel("TopK", TopKRel); + +} // namespace relay +} // namespace tvm + diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 50a25a9aff61..7569257830af 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -608,6 +608,45 @@ def verify(xshape, yshape, offset=None): verify((5, 32, 40, 40), (5, 32, 25, 25)) verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5)) +def test_forward_argsort(): + def verify(shape, axis, is_ascend, dtype="float32"): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype) + mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2, 3, 4), axis=0, is_ascend=False) + verify((1, 4, 6), axis=1, is_ascend=True) + verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32") + +def test_forward_topk(): + def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, + is_ascend=is_ascend, dtype=dtype) + mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, + is_ascend=is_ascend, dtype=dtype) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + if isinstance(ref_res, list): + assert len(op_res) == len(ref_res) + for i, t in enumerate(op_res): + tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy()) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((3, 4), k=1, axis=0, ret_type="both") + verify((3, 4), k=1, axis=-1, ret_type="indices") + verify((3, 5, 6), k=2, axis=2, ret_type="value") + verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True) + verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32") + if __name__ == '__main__': test_forward_mlp() @@ -650,3 +689,5 @@ def verify(xshape, yshape, offset=None): test_forward_bilinear_resize() test_forward_rnn_layer() test_forward_Crop() + test_forward_argsort() + test_forward_topk() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 023cdf5eb261..eebb73c95b1b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -754,6 +754,24 @@ def test_forward_split(): _test_split((3, 6, 4), -2, [1, 4, 1], 'float32') +###################################################################### +# TopKV2 +# ------ + +def _test_forward_top_k_v2(in_shape, k): + np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32") + tf.reset_default_graph() + in_data = tf.placeholder("float32", in_shape, name="in_data") + tf.math.top_k(in_data, k, name='TopK') + compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0') + +def test_forward_top_k_v2(): + _test_forward_top_k_v2((3,), 1) + _test_forward_top_k_v2((3,), 3) + _test_forward_top_k_v2((3, 5, 7), 3) + _test_forward_top_k_v2((3, 5, 7), 3) + + ####################################################################### # Unstack # ------- @@ -1704,6 +1722,7 @@ def test_placeholder(): test_forward_split() test_forward_unstack() test_forward_tile() + test_forward_top_k_v2() # Activations test_forward_sigmoid() diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 983a9154df34..76478baf5a19 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -16,18 +16,15 @@ # under the License. """ Support level6 operator test cases. """ -import math import numpy as np import tvm from tvm import relay from tvm.relay.testing import ctx_list -import topi.testing def test_argsort(): - def verify_argsort(shape, axis, is_ascend): + def verify_argsort(shape, axis, is_ascend, dtype): x = relay.var("x", relay.TensorType(shape, "float32")) - z = relay.argsort(x, axis=axis, is_ascend=is_ascend) - zz = relay.ir_pass.infer_type(z) + z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype) func = relay.Function([x], z) x_data = np.random.uniform(size=shape).astype("float32") if is_ascend: @@ -39,11 +36,58 @@ def verify_argsort(shape, axis, is_ascend): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5) - verify_argsort((2, 3, 4), axis=0, is_ascend=False) - verify_argsort((1, 4, 6), axis=1, is_ascend=True) - verify_argsort((3, 5, 6), axis=-1, is_ascend=False) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5) + for dtype in ["int32", "int64", "float32", "float64"]: + verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype) + verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype) + verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype) + + +def test_topk(): + def verify_topk(k, axis, ret_type, is_ascend, dtype): + shape = (20, 100) + x = relay.var("x", relay.TensorType(shape, "float32")) + out = relay.topk(x, k, axis, ret_type, is_ascend, dtype) + if isinstance(out, relay.expr.TupleWrapper): + out = out.astuple() + func = relay.Function([x], out) + np_data = np.random.uniform(size=shape).astype("float32") + if is_ascend: + np_indices = np.argsort(np_data, axis=axis) + else: + np_indices = np.argsort(-np_data, axis=axis) + kk = k if k >= 1 else shape[axis] + if axis == 0: + np_indices = np_indices[:kk, :] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[1]): + np_values[:, i] = np_data[np_indices[:, i], i] + else: + np_indices = np_indices[:, :kk] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[0]): + np_values[i, :] = np_data[i, np_indices[i, :]] + np_indices = np_indices.astype(dtype) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(np_data) + if ret_type == "both": + tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values) + tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices) + elif ret_type == "values": + tvm.testing.assert_allclose(op_res.asnumpy(), np_values) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), np_indices) + for k in [0, 1, 5]: + for axis in [0, -1, 1]: + for ret_type in ["both", "values", "indices"]: + for dtype in ["int64", "float32"]: + verify_topk(k, axis, ret_type, False, dtype) + verify_topk(k, axis, ret_type, True, dtype) if __name__ == "__main__": test_argsort() + test_topk() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 526429b91bee..403f67b972f7 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -21,3 +21,4 @@ from .ssd import * from .nms import * from .rcnn import * +from .sort import * diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 925cf24acd11..911dd84e2f05 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) + sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8) diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 678d494dae50..1d9148f46278 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -19,19 +19,48 @@ import tvm from tvm import api -from topi.sort import argsort -from topi.math import identity +from ..sort import argsort, topk +from ..math import identity +from ..transform import strided_slice from .. import generic from .. import tag +def _schedule_sort(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argsort + in the format of an array of tensors. -def sort_ir(data, output, axis, is_ascend): + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + from .injective import _schedule_injective + def traverse(op): + if tag.is_injective(op.tag): + _schedule_injective(op, s) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + for out in outs: + traverse(out.op) + return s + +def sort_ir(data, values_out, axis, is_ascend, indices_out=None): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. Parameters ---------- data: Buffer - Buffer of input data. + Buffer of input data. Data will be sorted in place. output : Buffer Output buffer of indicies of sorted tensor with same shape as data. @@ -47,14 +76,12 @@ def sort_ir(data, output, axis, is_ascend): stmt : Stmt The result IR statement. """ - size = 1 axis_mul_before = 1 axis_mul_after = 1 shape = data.shape if axis < 0: axis = len(shape) + axis for i, value in enumerate(shape, 0): - size *= value if i < axis: axis_mul_before *= value elif i > axis: @@ -62,52 +89,62 @@ def sort_ir(data, output, axis, is_ascend): max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() data = ib.buffer_ptr(data) - output = ib.buffer_ptr(output) + values_out = ib.buffer_ptr(values_out) + if indices_out is not None: + indices_out = ib.buffer_ptr(indices_out) nthread_tx = max_threads - nthread_bx = size // max_threads + 1 + nthread_bx = shape[axis] // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("vthread") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "virtual_thread", nthread_bx) tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local") - is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local") + if indices_out is not None: + temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local") with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_after) as j: - current_sort_num = shape[axis] base_idx = i * shape[axis] * axis_mul_after + j with ib.if_scope(tid < shape[axis]): - output[base_idx + tid * axis_mul_after] = tid.astype("float32") + values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after] + if indices_out is not None: + indices_out[base_idx + tid * axis_mul_after] = \ + tvm.generic.cast(tid, indices_out.dtype) + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + current_sort_num = shape[axis] + base_idx = i * shape[axis] * axis_mul_after + j # OddEvenTransposeSort with ib.for_range(0, current_sort_num) as k: with ib.if_scope(tid < (current_sort_num + 1) // 2): offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after - with ib.if_scope(tvm.all(is_ascend == 1, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ - data[offset] > data[offset + axis_mul_after])): - temp_data[0] = data[offset] - data[offset] = data[offset + axis_mul_after] - data[offset + axis_mul_after] = temp_data[0] - temp_index[0] = output[offset] - output[offset] = output[offset + axis_mul_after] - output[offset + axis_mul_after] = temp_index[0] - with ib.if_scope(tvm.all(is_ascend == 0, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ - data[offset] < data[offset + axis_mul_after])): - temp_data[0] = data[offset] - data[offset] = data[offset + axis_mul_after] - data[offset + axis_mul_after] = temp_data[0] - temp_index[0] = output[offset] - output[offset] = output[offset + axis_mul_after] - output[offset + axis_mul_after] = temp_index[0] + if is_ascend: + cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + values_out[offset] > values_out[offset + axis_mul_after]) + else: + cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + values_out[offset] < values_out[offset + axis_mul_after]) + with ib.if_scope(cond): + temp_data[0] = values_out[offset] + values_out[offset] = values_out[offset + axis_mul_after] + values_out[offset + axis_mul_after] = temp_data[0] + if indices_out is not None: + temp_index[0] = indices_out[offset] + indices_out[offset] = indices_out[offset + axis_mul_after] + indices_out[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.make.Call(None, 'tvm_storage_sync', tvm.convert(['shared']), tvm.expr.Call.Intrinsic, None, 0)) return ib.get() + def sort_nms_ir(data, valid_count, output, axis, is_ascend): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. @@ -197,7 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): return ib.get() @argsort.register(["cuda", "gpu"]) -def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): +def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -206,26 +243,27 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 data: tvm.Tensor The input array. - valid_count : tvm.Tensor + valid_count : tvm.Tensor, optional The number of valid elements to be sorted. - axis : int + axis : int, optional Axis long which to sort the input tensor. - is_ascend : boolean + is_ascend : boolean, optional Whether to sort in ascending or descending order. - flag : boolean - Whether this argsort is used in nms operator + dtype : string, optional + DType of the output indices. Returns ------- out : tvm.Tensor The output of this function. """ - sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) - sorted_data = identity(data) - if flag: + if valid_count is not None: + sorted_data = identity(data) + sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", + data_alignment=8) valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) @@ -239,16 +277,15 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 name="argsort_nms_gpu", tag="argsort_nms_gpu") else: - out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) - out = tvm.extern([data.shape], - [sorted_data], + value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + out = tvm.extern([data.shape, data.shape], + [data], lambda ins, outs: sort_ir( - ins[0], outs[0], axis, is_ascend), - dtype=dtype, - in_buffers=[sorted_data_buf], - out_buffers=[out_buf], + ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), + out_buffers=[value_buf, indices_buf], name="argsort_gpu", - tag="argsort_gpu") + tag="argsort_gpu")[1] return out @generic.schedule_argsort.register(["cuda", "gpu"]) @@ -266,17 +303,99 @@ def schedule_argsort(outs): s: Schedule The computation schedule for the op. """ - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - from .injective import _schedule_injective - def traverse(op): - if tag.is_broadcast(op.tag): - _schedule_injective(op, s) - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - scheduled_ops.append(op) - traverse(outs[0].op) + return _schedule_sort(outs) - return s +@topk.register(["cuda", "gpu"]) +def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): + """Get the top k elements in an input tensor along the given axis. + + Parameters + ---------- + data : tvm.Tensor + The input tensor. + + k : int, optional + Number of top elements to select. Return all elements if k < 1. + + axis : int, optional + Axis long which to sort the input tensor. + + ret_type: str, optional + The return type [both, values, indices]. + "both": return both top k data and indices. + "values": return top k data only. + "indices": return top k indices only. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. + + dtype : string, optional + The data type of the indices output. + + Returns + ------- + out : tvm.Tensor or List[tvm.Tensor] + The computed result. + """ + assert ret_type in ["both", "values", "indices"] + ndim = len(data.shape) + axis = axis + ndim if axis < 0 else axis + assert 0 <= axis < ndim + values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8) + indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8) + if ret_type == "values": + output = tvm.extern([data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], outs[0], axis, is_ascend), + out_buffers=[values_buf], + name="topk_gpu", + tag="topk_gpu") + else: + output = tvm.extern([data.shape, data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), + out_buffers=[values_buf, indices_buf], + name="topk_gpu", + tag="topk_gpu") + if k < 1: + if ret_type == "indices": + return output[1] + return output + beg = [0] * ndim + end = [] + for i in range(ndim): + if i == axis: + end.append(k) + else: + end.append(data.shape[i]) + if ret_type == "both": + values_out, indices_out = output + values_out = strided_slice(values_out, beg, end) + indices_out = strided_slice(indices_out, beg, end) + output = [values_out, indices_out] + elif ret_type == "values": + output = [strided_slice(output, beg, end)] + else: # ret_type == "indices" + indices_out = output[1] + output = [strided_slice(indices_out, beg, end)] + return output + + +@generic.schedule_topk.register(["cuda", "gpu"]) +def schedule_topk(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argsort + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _schedule_sort(outs) diff --git a/topi/python/topi/generic/sort.py b/topi/python/topi/generic/sort.py index 1ad088c50d04..5462f2ce917c 100644 --- a/topi/python/topi/generic/sort.py +++ b/topi/python/topi/generic/sort.py @@ -36,3 +36,20 @@ def schedule_argsort(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + +@tvm.target.generic_func +def schedule_topk(outs): + """Schedule for topk operator. + + Parameters + ---------- + outs: Array of Tensor + The indices that would sort an input array along + the given axis. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py index 84fff8d8f0cd..22899c4232f7 100644 --- a/topi/python/topi/sort.py +++ b/topi/python/topi/sort.py @@ -18,9 +18,10 @@ """Argsort operator""" import tvm from tvm import api +from .util import get_const_tuple @tvm.target.generic_func -def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): +def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): """Performs sorting along the given axis and returns an array of indices having the same shape as an input array that index data in sorted order. @@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): data : tvm.Tensor The input tensor. - valid_count : tvm.Tensor + valid_count : tvm.Tensor, optional 1-D tensor for valid number of boxes only for ssd. - axis : optional, int - Axis along which to sort the input tensor. + axis : int, optional + Axis along which to sort the input tensor. By default the flattened array is used. - is_ascend : optional, boolean + is_ascend : boolean, optional Whether to sort in ascending or descending order. - dtype : optional, string + dtype : string, optional DType of the output indices. - flag : optional, boolean - Whether valid_count is valid. - Returns ------- out : tvm.Tensor @@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): # An example to use argsort dshape = (1, 5, 6) data = tvm.placeholder(dshape, name="data") - valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") axis = 0 is_ascend = False - flag = False - out = argsort(data, valid_count, axis, is_ascend, flag) + out = argsort(data, axis=axis, is_ascend=is_ascend) np_data = np.random.uniform(dshape) - np_valid_count = np.array([4]) s = topi.generic.schedule_argsort(out) - f = tvm.build(s, [data, valid_count, out], "llvm") + f = tvm.build(s, [data, out], "llvm") ctx = tvm.cpu() tvm_data = tvm.nd.array(np_data, ctx) - tvm_valid_count = tvm.nd.array(np_valid_count, ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) - f(tvm_data, tvm_valid_count, tvm_out) + f(tvm_data, tvm_out) """ data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - if flag: + if valid_count is not None: valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8) @@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): name="argsort_cpu", tag="argsort_cpu") return out + + +@tvm.target.generic_func +def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): + """Get the top k elements in an input tensor along the given axis. + + Parameters + ---------- + data : tvm.Tensor + The input tensor. + + k : int, optional + Number of top elements to select. Return all elements if k < 1. + + axis : int, optional + Axis long which to sort the input tensor. + + ret_type: str, optional + The return type [both, values, indices]. + "both": return both top k data and indices. + "values": return top k data only. + "indices": return top k indices only. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. + + dtype : string, optional + The data type of the indices output. + + Returns + ------- + out : tvm.Tensor or List[tvm.Tensor] + The computed result. + """ + assert ret_type in ["both", "values", "indices"] + data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + out_shape = list(get_const_tuple(data.shape)) + if k >= 1: + out_shape[axis] = k + out_bufs = [] + if ret_type in ["both", "values"]: + out_bufs.append(api.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8)) + if ret_type in ["both", "indices"]: + out_bufs.append(api.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8)) + out_shapes = [out_shape] * len(out_bufs) + + out = tvm.extern(out_shapes, + [data], + lambda ins, outs: tvm.call_packed( + "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend), + in_buffers=[data_buf], + out_buffers=out_bufs, + name="topk_cpu", + tag="topk_cpu") + return out diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 2ad1f6e10057..04af1513576b 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None): ------- ret : tvm.Tensor """ + if strides is None: + strides = [] return cpp.strided_slice(a, begin, end, strides) diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 979565d31662..7c8d7db33059 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) - sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) + sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) out, box_indices = hybrid_nms(data, sort_tensor, valid_count, tvm.const(max_output_size, dtype="int32"), tvm.const(iou_threshold, dtype="float32"), diff --git a/topi/tests/python/test_topi_sort.py b/topi/tests/python/test_topi_sort.py index 3a2c9c2e4980..ed902b982a2b 100644 --- a/topi/tests/python/test_topi_sort.py +++ b/topi/tests/python/test_topi_sort.py @@ -16,23 +16,15 @@ # under the License. """Test code for vision package""" from __future__ import print_function -import math import numpy as np import tvm import topi import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple -from topi import argsort - def test_argsort(): - dshape = (1, 8) - valid_count_shape = (2,) + dshape = (20, 100) data = tvm.placeholder(dshape, name="data", dtype="float32") - valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) - np_valid_count = np.array([4]).astype(valid_count.dtype) np_result = np.argsort(-np_data) def check_device(device): ctx = tvm.context(device, 0) @@ -41,19 +33,77 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False) + out = topi.argsort(data, axis=-1, is_ascend=False) s = topi.generic.schedule_argsort(out) tvm_data = tvm.nd.array(np_data, ctx) - tvm_valid_count = tvm.nd.array(np_valid_count, ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) - f = tvm.build(s, [data, valid_count, out], device) - f(tvm_data, tvm_valid_count, tvm_out) + f = tvm.build(s, [data, out], device) + f(tvm_data, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) for device in ['llvm', 'cuda', 'opencl']: check_device(device) +def verify_topk(k, axis, ret_type, is_ascend, dtype): + shape = (20, 100) + data_dtype = "float32" + data = tvm.placeholder(shape, name="data", dtype=data_dtype) + + np_data = np.random.uniform(size=shape).astype(data_dtype) + if is_ascend: + np_indices = np.argsort(np_data, axis=axis) + else: + np_indices = np.argsort(-np_data, axis=axis) + kk = k if k >= 1 else shape[axis] + if axis == 0: + np_indices = np_indices[:kk, :] + np_values = np.zeros(np_indices.shape).astype(data_dtype) + for i in range(shape[1]): + np_values[:, i] = np_data[np_indices[:, i], i] + else: + np_indices = np_indices[:, :kk] + np_values = np.zeros(np_indices.shape).astype(data_dtype) + for i in range(shape[0]): + np_values[i, :] = np_data[i, np_indices[i, :]] + np_indices = np_indices.astype(dtype) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype) + outs = outs if isinstance(outs, list) else [outs] + s = topi.generic.schedule_topk(outs) + tvm_data = tvm.nd.array(np_data, ctx) + tvm_res = [] + for t in outs: + tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx)) + f = tvm.build(s, [data] + outs, device) + f(tvm_data, *tvm_res) + if ret_type == "both": + tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values) + tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_indices) + elif ret_type == "values": + tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values) + else: + tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices) + + for device in ['llvm', 'cuda', 'opencl']: + check_device(device) + +def test_topk(): + for k in [0, 1, 5]: + for axis in [0, -1, 1]: + for ret_type in ["both", "values", "indices"]: + for dtype in ["int64", "float32"]: + verify_topk(k, axis, ret_type, True, dtype) + verify_topk(k, axis, ret_type, False, dtype) + if __name__ == "__main__": test_argsort() + test_topk() From a79e078f377d6ecc547cad06b09d284c4ce35f8b Mon Sep 17 00:00:00 2001 From: ziheng Date: Tue, 4 Jun 2019 16:56:38 -0700 Subject: [PATCH 075/176] [LANG] Comparison operators support for Imm expressions (#3283) --- python/tvm/expr.py | 10 ++++++++++ python/tvm/relay/quantize/quantize.py | 13 +++++-------- src/relay/op/type_relations.cc | 8 ++++---- tests/python/unittest/test_lang_basic.py | 9 +++++++++ tests/python/unittest/test_lang_container.py | 7 +++++++ 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index a234ac4da53b..b4588e5d971a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -349,6 +349,16 @@ def __init__(self, value): self.__init_handle_by_constructor__( _make.StringImm, value) + def __eq__(self, other): + if isinstance(other, ConstExpr): + return self.value == other.value + return self.value == other + + def __ne__(self, other): + if isinstance(other, ConstExpr): + return self.value != other.value + return self.value != other + @register_node class Cast(Expr): diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 2423e76d308a..66c35b66a498 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,6 @@ from . import _quantize from .. import expr as _expr from .. import ir_pass as _ir_pass -from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -301,8 +300,6 @@ def optimize(func, params=None): "FoldConstant", "CanonicalizeOps"] - cfg = _transform.build_config(required_pass=opt_passes) - if params: name_dict = {} for arg in func.params: @@ -321,25 +318,25 @@ def optimize(func, params=None): bind_dict[arg] = _expr.const(v) func = _expr.bind(func, bind_dict) - if "SimplifyInference" in cfg.required_pass: + if "SimplifyInference" in opt_passes: func = _ir_pass.infer_type(func) func = _ir_pass.simplify_inference(func) - if "FoldConstant" in cfg.required_pass: + if "FoldConstant" in opt_passes: func = _ir_pass.fold_constant(func) - if "FoldScaleAxis" in cfg.required_pass: + if "FoldScaleAxis" in opt_passes: func = _ir_pass.infer_type(func) func = _ir_pass.backward_fold_scale_axis(func) func = _ir_pass.infer_type(func) func = _ir_pass.forward_fold_scale_axis(func) func = _ir_pass.fold_constant(func) - if "CanonicalizeOps" in cfg.required_pass: + if "CanonicalizeOps" in opt_passes: func = _ir_pass.infer_type(func) func = _ir_pass.canonicalize_ops(func) - if "FoldConstant" in cfg.required_pass: + if "FoldConstant" in opt_passes: func = _ir_pass.fold_constant(func) return func diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 16d09c46dfa2..5b147a489b44 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -108,8 +108,8 @@ bool BroadcastRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); - DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] - << ",Out:" << types[2] << std::endl; + // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] + // << ",Out:" << types[2] << std::endl; if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); @@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); - DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] - << ",Out:" << types[2] << std::endl; + // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] + // << ",Out:" << types[2] << std::endl; if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 25c0aa14bad7..0ace220ab6e1 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -163,6 +163,14 @@ def test_equality(): d = (c != c) assert not d + +def test_equality_string_imm(): + x = 'a' + y = tvm.make.StringImm(x) + x == y.value + x == y + + if __name__ == "__main__": test_cast() test_attr() @@ -178,3 +186,4 @@ def test_equality(): test_all() test_bitwise() test_equality() + test_equality_string_imm() diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index cce7479a4278..999e379ca48a 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -65,9 +65,16 @@ def test_map_save_load_json(): assert(dd == {"a": 2, "b": 3}) +def test_in_container(): + arr = tvm.convert(['a', 'b', 'c']) + assert 'a' in arr + assert tvm.make.StringImm('a') in arr + assert 'd' not in arr + if __name__ == "__main__": test_str_map() test_array() test_map() test_array_save_load_json() test_map_save_load_json() + test_in_container() From 1dbe83d6c3701afe92f5201bdc1f570af10aef8b Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Tue, 4 Jun 2019 20:32:31 -0700 Subject: [PATCH 076/176] [IR] Try to improve nms and get_valid_count (#3282) * improve nms * add back get_valid_count syncs --- topi/python/topi/cuda/nms.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 911dd84e2f05..460584bc8b78 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -457,15 +457,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): box_indices = ib.buffer_ptr(box_indices) num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") - max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) - k = bx * max_threads + tx + j = bx * max_threads + tx iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold) top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) @@ -480,22 +480,22 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): nkeep = if_then_else( \ tvm.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) - with ib.for_range(0, nkeep) as j: - with ib.if_scope(k < box_data_length): + with ib.if_scope(j < nkeep): + with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = \ data[(base_idx + sorted_index[i * num_anchors + j] \ * box_data_length + k)] box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])): - with ib.for_range(0, valid_count[i] - nkeep) as j: - with ib.if_scope(k < box_data_length): + with ib.if_scope(j < valid_count[i] - nkeep): + with ib.for_range(0, box_data_length) as k: out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 box_indices[i * num_anchors + (j + nkeep)] = -1 # Apply nms - with ib.for_range(0, valid_count[i]) as j: + with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length with ib.if_scope(out[base_idx + offset_j] >= 0): - with ib.if_scope(k < valid_count[i]): + with ib.for_range(0, valid_count[i]) as k: offset_k = k * box_data_length with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \ tvm.any(force_suppress > 0, id_index < 0, \ @@ -506,35 +506,29 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(iou >= iou_threshold): out[base_idx + offset_k] = -1.0 box_indices[i * num_anchors + k] = -1 - ib.emit(tvm.make.Call(None, 'tvm_storage_sync', - tvm.convert(['shared']), - tvm.expr.Call.Intrinsic, None, 0)) with ib.else_scope(): - with ib.for_range(0, valid_count[i]) as j: + with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length - with ib.if_scope(k < box_data_length): + with ib.for_range(0, box_data_length) as k: out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] box_indices[i * num_anchors + j] = j # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - valid_count[i]) as j: - with ib.if_scope(k < box_data_length): + with ib.if_scope(j < num_anchors - valid_count[i]): + with ib.for_range(0, box_data_length) as k: out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 box_indices[i * num_anchors + j + valid_count[i]] = -1 # Only return max_output_size number of valid boxes num_valid_boxes[0] = 0 with ib.if_scope(max_output_size > 0): - with ib.for_range(0, valid_count[i]) as j: + with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length with ib.if_scope(out[base_idx + offset_j] >= 0): with ib.if_scope(num_valid_boxes[0] == max_output_size): - with ib.if_scope(k < box_data_length): + with ib.for_range(0, box_data_length) as k: out[base_idx + offset_j + k] = -1.0 box_indices[i * num_anchors + j] = -1 with ib.else_scope(): num_valid_boxes[0] += 1 - ib.emit(tvm.make.Call(None, 'tvm_storage_sync', - tvm.convert(['shared']), - tvm.expr.Call.Intrinsic, None, 0)) return ib.get() From cd3248e9abcde27f5ca554fa9942aede7bbb1993 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 5 Jun 2019 09:28:52 -0700 Subject: [PATCH 077/176] [Relay][VM] Fix code generation for packed functions + tuples (#3287) --- src/relay/backend/vm/compiler.cc | 52 ++++++++++++++++++++++++-------- tests/python/relay/test_vm.py | 13 ++++++++ 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 602e92759624..db98a9a9d3fd 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor { return Instruction::AllocTensor(last_register, dltype, NewRegister()); } - void EmitInvokePrimitive(const Function& func, std::vector args_registers, + void EmitInvokePrimitive(const Function& func, + const std::vector& args_registers, const Type& ret_type) { + std::vector unpacked_arg_regs; std::vector allocs; - size_t return_num = 0; + + // Arity calculation must flatten tuples. + size_t arity = 0; + CHECK_EQ(func->params.size(), args_registers.size()); + for (size_t i = 0; i < func->params.size(); i++) { + auto ty = func->params[i]->checked_type(); + if (ty.as()) { + unpacked_arg_regs.push_back(args_registers[i]); + arity += 1; + } else if (auto tuple_ty = ty.as()) { + for (size_t f = 0; f < tuple_ty->fields.size(); f++) { + const auto& field = tuple_ty->fields[f]; + CHECK(field.as()) + << "only supports non-nested tuples currently " + << "found " << field; + auto dst = NewRegister(); + Emit(Instruction::GetField(args_registers[i], f, dst)); + unpacked_arg_regs.push_back(dst); + } + arity += tuple_ty->fields.size(); + } else { + LOG(FATAL) << "unsupported parameter type " << ty; + } + } + + size_t return_val_count = 0; if (const TensorTypeNode* ttype = ret_type.as()) { // Allocate space for the return tensor. auto alloc = AllocTensorFromType(ttype); allocs.push_back(alloc); - return_num = 1; + return_val_count = 1; } else if (const TupleTypeNode* ttype = ret_type.as()) { std::vector fields_registers; @@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor { allocs.push_back(AllocTensorFromType(f_type)); fields_registers.push_back(allocs.back().dst); } - return_num = ttype->fields.size(); + return_val_count = ttype->fields.size(); } else { LOG(FATAL) << "Unsupported return value type"; } + arity += return_val_count; for (auto& alloc : allocs) { Emit(alloc); - args_registers.push_back(alloc.dst); + unpacked_arg_regs.push_back(alloc.dst); } // Next generate the invoke instruction. @@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor { op_index = seen_funcs[cfunc->funcs[0]]; } - // If Tensor, 1 - // If Tuple, size of tuple - size_t arity = func->params.size() + return_num; - Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers)); - if (return_num > 1) { + Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); + + if (return_val_count > 1) { // return value is a tuple, we need to create a tuple std::vector fields_registers; - for (size_t i = func->params.size(); i < arity; ++i) { - fields_registers.push_back(args_registers[i]); + for (size_t i = arity - return_val_count; i < arity; ++i) { + fields_registers.push_back(unpacked_arg_regs[i]); } - Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister())); + Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister())); } } diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index bc99418d5da4..d727e776cbcd 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -49,6 +49,17 @@ def test_split(): res = veval(f, x_data) tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) +def test_split_no_fuse(): + x = relay.var('x', shape=(12,)) + y = relay.split(x, 3, axis=0).astuple() + z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0) + z = relay.annotation.stop_fusion(z) + f = relay.Function([x], z) + x_data = np.random.rand(12,).astype('float32') + res = veval(f, x_data) + tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) + + def test_id(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x) @@ -259,6 +270,8 @@ def test_closure(): test_tuple_second() test_let_scalar() test_let_tensor() + test_split() + test_split_no_fuse() # TODO(@jroesch): restore when match is supported # test_list_constructor() test_closure() From b8bd444dafa29b7c27690a94d172652a7cb2221c Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 5 Jun 2019 09:29:43 -0700 Subject: [PATCH 078/176] Improve error message for custom tflite operators (#3284) --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index bfd63bb0140e..3c3808d09712 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -112,7 +112,7 @@ def get_op_code_str(self, op): op_code_str = self.builtin_op_code[op_code_id] if op_code_id == BuiltinOperator.CUSTOM: # Custom operator - raise NotImplementedError("Not Support Custom Operator Now") + raise NotImplementedError("Custom operators are currently not supported") return op_code_str def get_input_tensors(self, op): From 92c10ec6aae3db2678ec91e3b8ba7cc7639deb8a Mon Sep 17 00:00:00 2001 From: abergeron Date: Wed, 5 Jun 2019 13:14:12 -0400 Subject: [PATCH 079/176] More fixes and tweaks to the cuda conda packages (#3281) --- .gitignore | 4 + conda/Dockerfile.cuda92 | 33 -------- ...Dockerfile.cuda100 => Dockerfile.template} | 10 ++- conda/{build_cuda.sh => Makefile} | 18 ++--- conda/build_cuda.py | 76 +++++++++++++++++++ conda/cross-linux.cmake | 37 --------- conda/nnvm/meta.yaml | 2 +- conda/topi/meta.yaml | 2 +- conda/tvm-libs/build.sh | 25 +----- conda/tvm-libs/meta.yaml | 3 +- conda/tvm/meta.yaml | 2 +- 11 files changed, 104 insertions(+), 108 deletions(-) delete mode 100644 conda/Dockerfile.cuda92 rename conda/{Dockerfile.cuda100 => Dockerfile.template} (72%) rename conda/{build_cuda.sh => Makefile} (66%) mode change 100755 => 100644 create mode 100644 conda/build_cuda.py delete mode 100644 conda/cross-linux.cmake diff --git a/.gitignore b/.gitignore index a7355739cf59..b23847a5e812 100644 --- a/.gitignore +++ b/.gitignore @@ -220,3 +220,7 @@ patched.txt # pipenv file Pipfile Pipfile.lock + +# conda package artifacts +conda/Dockerfile.cuda* +conda/pkg diff --git a/conda/Dockerfile.cuda92 b/conda/Dockerfile.cuda92 deleted file mode 100644 index ad2d8ffca6e0..000000000000 --- a/conda/Dockerfile.cuda92 +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -FROM nvidia/cuda:9.2-devel-centos6 - -RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - /opt/conda/bin/conda install conda-build conda-verify && \ - /opt/conda/bin/conda clean -ya - -ENV PATH /opt/conda/bin:$PATH -ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 - -WORKDIR /workspace -RUN chmod -R a+w /workspace - -CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: 9.2}' /workspace/conda/tvm-libs diff --git a/conda/Dockerfile.cuda100 b/conda/Dockerfile.template similarity index 72% rename from conda/Dockerfile.cuda100 rename to conda/Dockerfile.template index def8c9ac5d6a..59b9ac96814e 100644 --- a/conda/Dockerfile.cuda100 +++ b/conda/Dockerfile.template @@ -15,7 +15,13 @@ # specific language governing permissions and limitations # under the License. -FROM nvidia/cuda:10.0-devel-centos6 +FROM nvidia/cuda:{{ cuda_version }}-devel-centos6 + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ + tar --no-same-owner -xzf cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -C /usr/local && \ + rm cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz && \ + ldconfig + RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ chmod +x ~/miniconda.sh && \ @@ -30,4 +36,4 @@ ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 WORKDIR /workspace RUN chmod -R a+w /workspace -CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: 10.0}' /workspace/conda/tvm-libs +CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: {{ cuda_version }}}' /workspace/conda/tvm-libs diff --git a/conda/build_cuda.sh b/conda/Makefile old mode 100755 new mode 100644 similarity index 66% rename from conda/build_cuda.sh rename to conda/Makefile index 2f3207e22987..cda546ac73ce --- a/conda/build_cuda.sh +++ b/conda/Makefile @@ -5,22 +5,18 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -#/bin/sh -condadir=`dirname $0` -condadir=`readlink -f $condadir` -srcdir=`dirname $condadir` -docker build -t tvm-cuda100-forge $condadir -f $condadir/Dockerfile.cuda100 -docker run --rm -v $srcdir:/workspace tvm-cuda100-forge -docker build -t tvm-cuda92-forge $condadir -f $condadir/Dockerfile.cuda92 -docker run --rm -v $srcdir:/workspace tvm-cuda92-forge -sudo chown -R `whoami` $condadir/pkg +packages: + conda build tvm-libs + conda build tvm + conda build topi + conda built nnvm diff --git a/conda/build_cuda.py b/conda/build_cuda.py new file mode 100644 index 000000000000..47af6ce4564e --- /dev/null +++ b/conda/build_cuda.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import subprocess + +from jinja2 import Template + +CUDA_VERSIONS = ['10.0', '9.0'] + + +# Make sure that the cudnn version you set here is available +# for all the cuda versions that you want both from nvidia +# and from conda. + +# These two must be in sync +CUDNN_FULL_VERSION = '7.3.1.20' +CUDNN_VERSION = '7.3.1' + + +condadir = os.path.dirname(sys.argv[0]) +condadir = os.path.abspath(condadir) +srcdir = os.path.dirname(condadir) + + +with open(os.path.join(condadir, 'Dockerfile.template')) as f: + docker_template = Template(f.read()) + + +def render_dockerfile(version): + txt = docker_template.render(cuda_version=version, + cudnn_short_version=CUDNN_VERSION, + cudnn_version=CUDNN_FULL_VERSION) + fname = os.path.join(condadir, + 'Dockerfile.cuda' + version.replace('.', '')) + with open(fname, 'w') as f: + f.write(txt) + return fname + + +def build_docker(version): + vv = version.replace('.', '') + fname = render_dockerfile(version) + tagname = f'tvm-cuda{ vv }-forge' + subprocess.run(['docker', 'build', '-t', tagname, + condadir, '-f', fname], check=True) + return tagname + + +def build_pkg(version): + tagname = build_docker(version) + subprocess.run(['docker', 'run', '--rm', '-v', f'{ srcdir }:/workspace', + tagname], check=True) + + +if __name__ == '__main__': + build_versions = CUDA_VERSIONS + if len(sys.argv) > 1: + build_versions = sys.argv[1:] + for version in build_versions: + build_pkg(version) diff --git a/conda/cross-linux.cmake b/conda/cross-linux.cmake deleted file mode 100644 index f84ba8e44a26..000000000000 --- a/conda/cross-linux.cmake +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# this one is important -set(CMAKE_SYSTEM_NAME Linux) -set(CMAKE_PLATFORM Linux) -#this one not so much -set(CMAKE_SYSTEM_VERSION 1) - -# specify the cross compiler -set(CMAKE_C_COMPILER $ENV{CC}) - -# where is the target environment -set(CMAKE_FIND_ROOT_PATH $ENV{PREFIX} $ENV{BUILD_PREFIX}/$ENV{HOST}/sysroot) - -# search for programs in the build host directories -set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) -# for libraries and headers in the target directories -set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) -set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) - -# god-awful hack because it seems to not run correct tests to determine this: -set(__CHAR_UNSIGNED___EXITCODE 1) diff --git a/conda/nnvm/meta.yaml b/conda/nnvm/meta.yaml index 883655f335cb..d948484a61e5 100644 --- a/conda/nnvm/meta.yaml +++ b/conda/nnvm/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 skip: True # [win] requirements: diff --git a/conda/topi/meta.yaml b/conda/topi/meta.yaml index bbba452a6422..f4bc8950d4c4 100644 --- a/conda/topi/meta.yaml +++ b/conda/topi/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 requirements: host: diff --git a/conda/tvm-libs/build.sh b/conda/tvm-libs/build.sh index d4cf2578b570..e0b85910475e 100644 --- a/conda/tvm-libs/build.sh +++ b/conda/tvm-libs/build.sh @@ -16,42 +16,25 @@ # specific language governing permissions and limitations # under the License. -# Fix for OSX build to hide the clang LLVM -rm -f ${BUILD_PREFIX}/bin/llvm-config -rm -rf ${BUILD_PREFIX}/lib/cmake - set -e -if [ -z "$PREFIX" ]; then - PREFIX="$CONDA_PREFIX" -fi - -if [ -z "$cuda" ] || [ "$cuda" == "False" ]; then - CUDA_OPT="" +if [ "$cuda" == "True" ]; then + CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" else - CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON" + CUDA_OPT="" fi if [ "$target_platform" == "osx-64" ]; then # macOS 64 bits METAL_OPT="" # Conda can only target 10.9 for now - TOOLCHAIN_OPT="" else METAL_OPT="" - if [ "$target_platform" == "linux-64" ]; then - # Linux 64 bits - TOOLCHAIN_OPT="-DCMAKE_TOOLCHAIN_FILE=${RECIPE_DIR}/../cross-linux.cmake" - else - # Windows (or 32 bits, which we don't support) - METAL_OPT="" - TOOLCHAIN_OPT="" - fi fi rm -rf build || true mkdir -p build cd build -cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=ON -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" $TOOLCHAIN_OPT .. +cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" .. make -j${CPU_COUNT} VERBOSE=1 make install cd .. diff --git a/conda/tvm-libs/meta.yaml b/conda/tvm-libs/meta.yaml index 5126f5b30359..aad8f251c2a6 100644 --- a/conda/tvm-libs/meta.yaml +++ b/conda/tvm-libs/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 string: cuda{{ cuda_version }}_{{ PKG_BUILDNUM }} # [cuda] requirements: @@ -39,6 +39,7 @@ requirements: - zlib # [linux] run: - {{ pin_compatible('cudatoolkit', lower_bound=cuda_version, max_pin='x.x') }} # [cuda] + - {{ pin_compatible('cudnn', lower_bound='7.3.1', max_pin='x') }} # [cuda] about: home: https://github.com/dmlc/tvm diff --git a/conda/tvm/meta.yaml b/conda/tvm/meta.yaml index 693237ce07c0..221dc7950f75 100644 --- a/conda/tvm/meta.yaml +++ b/conda/tvm/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 requirements: build: From cdabfa90bee53835b43d7dda380432c1a69f8957 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Wed, 5 Jun 2019 10:17:11 -0700 Subject: [PATCH 080/176] [VTA] [Hardware] Chisel implementation (#3258) --- cmake/config.cmake | 3 - cmake/modules/VTA.cmake | 16 +- vta/apps/tsim_example/README.md | 2 +- vta/apps/tsim_example/cmake/modules/hw.cmake | 2 +- vta/hardware/chisel/Makefile | 76 ++++ .../src/main/resources/verilog/VTAHostDPI.v | 2 +- .../chisel/src/main/scala/core/Compute.scala | 201 ++++++++++ .../chisel/src/main/scala/core/Configs.scala | 46 +++ .../chisel/src/main/scala/core/Core.scala | 109 ++++++ .../chisel/src/main/scala/core/Decode.scala | 229 +++++++++++ .../chisel/src/main/scala/core/Fetch.scala | 197 ++++++++++ .../chisel/src/main/scala/core/ISA.scala | 93 +++++ .../chisel/src/main/scala/core/Load.scala | 131 +++++++ .../chisel/src/main/scala/core/LoadUop.scala | 214 ++++++++++ .../src/main/scala/core/Semaphore.scala | 42 ++ .../chisel/src/main/scala/core/Store.scala | 114 ++++++ .../src/main/scala/core/TensorAlu.scala | 295 ++++++++++++++ .../src/main/scala/core/TensorGemm.scala | 364 ++++++++++++++++++ .../src/main/scala/core/TensorLoad.scala | 278 +++++++++++++ .../src/main/scala/core/TensorStore.scala | 224 +++++++++++ .../src/main/scala/core/TensorUtil.scala | 304 +++++++++++++++ .../chisel/src/main/scala/core/package.scala | 23 ++ .../src/main/scala/dpi/VTAHostDPI.scala | 83 ++++ .../chisel/src/main/scala/dpi/VTAMemDPI.scala | 98 +++++ .../src/main/scala/interface/axi/AXI.scala | 312 +++++++++++++++ .../chisel/src/main/scala/shell/Configs.scala | 51 +++ .../src/main/scala/shell/SimShell.scala | 78 ++++ .../chisel/src/main/scala/shell/VCR.scala | 242 ++++++++++++ .../chisel/src/main/scala/shell/VME.scala | 254 ++++++++++++ .../src/main/scala/shell/VTAShell.scala | 57 +++ .../src/main/scala/shell/XilinxShell.scala | 117 ++++++ .../chisel/src/main/scala/test/Test.scala | 33 ++ .../chisel/src/main/scala/util/Config.scala | 104 +++++ .../util/GenericParameterizedBundle.scala | 40 ++ .../chisel/src/main/scala/vta/Configs.scala | 51 +++ vta/hardware/dpi/tsim_device.cc | 10 + vta/include/vta/driver.h | 16 + vta/python/vta/environment.py | 2 +- vta/python/vta/testing/simulator.py | 19 + vta/python/vta/testing/util.py | 5 +- vta/src/runtime.cc | 63 ++- vta/src/tsim/tsim_driver.cc | 179 +++++++++ vta/tests/python/unittest/test_vta_insn.py | 28 +- 43 files changed, 4784 insertions(+), 23 deletions(-) create mode 100644 vta/hardware/chisel/src/main/scala/core/Compute.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Configs.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Core.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Decode.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Fetch.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/ISA.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Load.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/LoadUop.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Semaphore.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/Store.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorAlu.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorGemm.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorLoad.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorStore.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/TensorUtil.scala create mode 100644 vta/hardware/chisel/src/main/scala/core/package.scala create mode 100644 vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/Configs.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/SimShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VCR.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VME.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/VTAShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala create mode 100644 vta/hardware/chisel/src/main/scala/test/Test.scala create mode 100644 vta/hardware/chisel/src/main/scala/util/Config.scala create mode 100644 vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala create mode 100644 vta/hardware/chisel/src/main/scala/vta/Configs.scala create mode 100644 vta/src/tsim/tsim_driver.cc diff --git a/cmake/config.cmake b/cmake/config.cmake index e7ddb9aba6b8..679de8d7e752 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -135,9 +135,6 @@ set(USE_TENSORRT OFF) # Build ANTLR parser for Relay text format set(USE_ANTLR OFF) -# Build TSIM for VTA -set(USE_VTA_TSIM OFF) - # Whether use Relay debug mode set(USE_RELAY_DEBUG OFF) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 1df6c6676fac..6d5ea000edc2 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -29,8 +29,7 @@ elseif(PYTHON) --use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) endif() - execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target) - string(STRIP ${__vta_target} VTA_TARGET) + execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE) message(STATUS "Build VTA runtime with target: " ${VTA_TARGET}) @@ -44,6 +43,13 @@ elseif(PYTHON) add_library(vta SHARED ${VTA_RUNTIME_SRCS}) + if(${VTA_TARGET} STREQUAL "tsim") + target_compile_definitions(vta PUBLIC USE_TSIM) + include_directories("vta/include") + file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) + endif() + target_include_directories(vta PUBLIC vta/include) foreach(__def ${VTA_DEFINITIONS}) @@ -61,12 +67,6 @@ elseif(PYTHON) target_link_libraries(vta ${__cma_lib}) endif() - if(NOT USE_VTA_TSIM STREQUAL "OFF") - include_directories("vta/include") - file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) - endif() - else() message(STATUS "Cannot found python in env, VTA build is skipped..") endif() diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index b557b24ac690..dc06a92f2b0e 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -49,7 +49,7 @@ sudo apt install verilator sbt ## Setup in TVM 1. Install `verilator` and `sbt` as described above -2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake +2. Set the VTA TARGET to `tsim` on `/vta/config/vta_config.json` 3. Build tvm ## How to run VTA TSIM examples diff --git a/vta/apps/tsim_example/cmake/modules/hw.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake index 87dd72b2e626..e016ea03b6fa 100644 --- a/vta/apps/tsim_example/cmake/modules/hw.cmake +++ b/vta/apps/tsim_example/cmake/modules/hw.cmake @@ -124,7 +124,7 @@ else() file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc) add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) - set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) + set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) if (NOT TSIM_USE_TRACE STREQUAL "OFF") list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd) else() diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 65a9ed13c989..7371dd1b3686 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -15,5 +15,81 @@ # specific language governing permissions and limitations # under the License. +CONFIG = DefaultF1Config +TOP = VTA +TOP_TEST = Test +BUILD_NAME = build +USE_TRACE = 0 +VTA_LIBNAME = libvta_hw + +config_test = $(TOP_TEST)$(CONFIG) +vta_dir = $(abspath ../../) +tvm_dir = $(abspath ../../../) +verilator_inc_dir = /usr/local/share/verilator/include +verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator +chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel + +verilator_opt = --cc +verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN +verilator_opt += +define+RANDOMIZE_REG_INIT +verilator_opt += +define+RANDOMIZE_MEM_INIT +verilator_opt += --x-assign unique +verilator_opt += --output-split 20000 +verilator_opt += --output-split-cfuncs 20000 +verilator_opt += --top-module ${TOP_TEST} +verilator_opt += -Mdir ${verilator_build_dir} +verilator_opt += -I$(chisel_build_dir) + +cxx_flags = -O2 -Wall -fPIC -shared +cxx_flags += -fvisibility=hidden -std=c++11 +cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST) +cxx_flags += -DVL_PRINTF=printf +cxx_flags += -DVL_USER_FINISH +cxx_flags += -DVM_COVERAGE=0 +cxx_flags += -DVM_SC=0 +cxx_flags += -Wno-sign-compare +cxx_flags += -include V$(TOP_TEST).h +cxx_flags += -I$(verilator_build_dir) +cxx_flags += -I$(verilator_inc_dir) +cxx_flags += -I$(verilator_inc_dir)/vltstd +cxx_flags += -I$(vta_dir)/include +cxx_flags += -I$(tvm_dir)/include +cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include + +cxx_files = $(verilator_inc_dir)/verilated.cpp +cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp +cxx_files += $(wildcard $(verilator_build_dir)/*.cpp) +cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc + +ifneq ($(USE_TRACE), 0) + verilator_opt += --trace + cxx_flags += -DVM_TRACE=1 + cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd + cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp +else + cxx_flags += -DVM_TRACE=0 +endif + +default: lib + +lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so +$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp + g++ $(cxx_flags) $(cxx_files) -o $@ + +verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp +$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v + verilator $(verilator_opt) $< + +verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v +$(chisel_build_dir)/$(TOP).$(CONFIG).v: + sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)' + +verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v +$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v: + sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)' + clean: -rm -rf target project/target project/project + +cleanall: + -rm -rf $(vta_dir)/$(BUILD_NAME) diff --git a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v index 02fcf0d779e1..8ab85f6b752c 100644 --- a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v +++ b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v @@ -112,7 +112,7 @@ module VTAHostDPI # always_ff @(posedge clock) begin if (__exit == 'd1) begin - $display("[DONE] at cycle:%016d", cycles); + $display("[TSIM] Verilog $finish called at cycle:%016d", cycles); $finish; end end diff --git a/vta/hardware/chisel/src/main/scala/core/Compute.scala b/vta/hardware/chisel/src/main/scala/core/Compute.scala new file mode 100644 index 000000000000..ef56c3d4224e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Compute.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Compute. + * + * The compute unit is in charge of the following: + * - Loading micro-ops from memory (loadUop module) + * - Loading biases (acc) from memory (tensorAcc module) + * - Compute ALU instructions (tensorAlu module) + * - Compute GEMM instructions (tensorGemm module) + */ +class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Vec(2, Input(Bool())) + val o_post = Vec(2, Output(Bool())) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val uop_baddr = Input(UInt(mp.addrBits.W)) + val acc_baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = Vec(2, new VMEReadMaster) + val inp = new TensorMaster(tensorType = "inp") + val wgt = new TensorMaster(tensorType = "wgt") + val out = new TensorMaster(tensorType = "out") + val finish = Output(Bool()) + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0))) + + val loadUop = Module(new LoadUop) + val tensorAcc = Module(new TensorLoad(tensorType = "acc")) + val tensorGemm = Module(new TensorGemm) + val tensorAlu = Module(new TensorAlu) + + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + // decode + val dec = Module(new ComputeDecode) + dec.io.inst := inst_q.io.deq.bits + + val inst_type = Cat(dec.io.isFinish, + dec.io.isAlu, + dec.io.isGemm, + dec.io.isLoadAcc, + dec.io.isLoadUop).asUInt + + val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B) + val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B) + val start = snext & sprev + val done = + MuxLookup(inst_type, + false.B, // default + Array( + "h_01".U -> loadUop.io.done, + "h_02".U -> tensorAcc.io.done, + "h_04".U -> tensorGemm.io.done, + "h_08".U -> tensorAlu.io.done, + "h_10".U -> true.B // Finish + ) + ) + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (inst_type.orR) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // uop + loadUop.io.start := state === sIdle & start & dec.io.isLoadUop + loadUop.io.inst := inst_q.io.deq.bits + loadUop.io.baddr := io.uop_baddr + io.vme_rd(0) <> loadUop.io.vme_rd + loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx) + + // acc + tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc + tensorAcc.io.inst := inst_q.io.deq.bits + tensorAcc.io.baddr := io.acc_baddr + tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx) + tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr) + io.vme_rd(1) <> tensorAcc.io.vme_rd + + // gemm + tensorGemm.io.start := state === sIdle & start & dec.io.isGemm + tensorGemm.io.inst := inst_q.io.deq.bits + tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm + tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits + tensorGemm.io.inp <> io.inp + tensorGemm.io.wgt <> io.wgt + tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm + tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits + tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm + tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits + + // alu + tensorAlu.io.start := state === sIdle & start & dec.io.isAlu + tensorAlu.io.inst := inst_q.io.deq.bits + tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu + tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits + tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu + tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits + tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu + tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits + + // out + io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx) + io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr) + + // semaphore + s(0).io.spost := io.i_post(0) + s(1).io.spost := io.i_post(1) + s(0).io.swait := dec.io.pop_prev & (state === sIdle & start) + s(1).io.swait := dec.io.pop_next & (state === sIdle & start) + io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync)) + io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync)) + + // finish + io.finish := state === sExe & done & dec.io.isFinish + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Compute] start sync\n") + } .elsewhen (dec.io.isLoadUop) { + printf("[Compute] start load uop\n") + } .elsewhen (dec.io.isLoadAcc) { + printf("[Compute] start load acc\n") + } .elsewhen (dec.io.isGemm) { + printf("[Compute] start gemm\n") + } .elsewhen (dec.io.isAlu) { + printf("[Compute] start alu\n") + } .elsewhen (dec.io.isFinish) { + printf("[Compute] start finish\n") + } + } + // done + when (state === sSync) { + printf("[Compute] done sync\n") + } + when (state === sExe) { + when (done) { + when (dec.io.isLoadUop) { + printf("[Compute] done load uop\n") + } .elsewhen (dec.io.isLoadAcc) { + printf("[Compute] done load acc\n") + } .elsewhen (dec.io.isGemm) { + printf("[Compute] done gemm\n") + } .elsewhen (dec.io.isAlu) { + printf("[Compute] done alu\n") + } .elsewhen (dec.io.isFinish) { + printf("[Compute] done finish\n") + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/Configs.scala b/vta/hardware/chisel/src/main/scala/core/Configs.scala new file mode 100644 index 000000000000..b4e764b120cd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Configs.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import vta.util.config._ + +/** CoreConfig. + * + * This is one supported configuration for VTA. This file will + * be eventually filled out with class configurations that can be + * mixed/matched with Shell configurations for different backends. + */ +class CoreConfig extends Config((site, here, up) => { + case CoreKey => CoreParams( + batch = 1, + blockOut = 16, + blockIn = 16, + inpBits = 8, + wgtBits = 8, + uopBits = 32, + accBits = 32, + outBits = 8, + uopMemDepth = 2048, + inpMemDepth = 2048, + wgtMemDepth = 1024, + accMemDepth = 2048, + outMemDepth = 2048, + instQueueEntries = 512) +}) diff --git a/vta/hardware/chisel/src/main/scala/core/Core.scala b/vta/hardware/chisel/src/main/scala/core/Core.scala new file mode 100644 index 000000000000..2a2d4e02784f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Core.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import vta.util.config._ +import vta.shell._ + +/** Core parameters */ +case class CoreParams ( + batch: Int = 1, + blockOut: Int = 16, + blockIn: Int = 16, + inpBits: Int = 8, + wgtBits: Int = 8, + uopBits: Int = 32, + accBits: Int = 32, + outBits: Int = 8, + uopMemDepth: Int = 512, + inpMemDepth: Int = 512, + wgtMemDepth: Int = 512, + accMemDepth: Int = 512, + outMemDepth: Int = 512, + instQueueEntries: Int = 32 +) + +case object CoreKey extends Field[CoreParams] + +/** Core. + * + * The core defines the current VTA architecture by connecting memory and + * compute modules together such as load/store and compute. Most of the + * connections in the core are bulk (<>), and we should try to keep it this + * way, because it is easier to understand what is going on. + * + * Also, the core must be instantiated by a shell using the + * VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces. + * More info about these interfaces and modules can be found in the shell + * directory. + */ +class Core(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val vcr = new VCRClient + val vme = new VMEMaster + }) + val fetch = Module(new Fetch) + val load = Module(new Load) + val compute = Module(new Compute) + val store = Module(new Store) + + // Read(rd) and write(wr) from/to memory (i.e. DRAM) + io.vme.rd(0) <> fetch.io.vme_rd + io.vme.rd(1) <> compute.io.vme_rd(0) + io.vme.rd(2) <> load.io.vme_rd(0) + io.vme.rd(3) <> load.io.vme_rd(1) + io.vme.rd(4) <> compute.io.vme_rd(1) + io.vme.wr(0) <> store.io.vme_wr + + // Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs) + fetch.io.launch := io.vcr.launch + fetch.io.ins_baddr := io.vcr.ptrs(0) + fetch.io.ins_count := io.vcr.vals(0) + + // Load inputs and weights from memory (DRAM) into scratchpads (SRAMs) + load.io.i_post := compute.io.o_post(0) + load.io.inst <> fetch.io.inst.ld + load.io.inp_baddr := io.vcr.ptrs(2) + load.io.wgt_baddr := io.vcr.ptrs(3) + + // The compute module performs the following: + // - Load micro-ops (uops) and accumulations (acc) + // - Compute dense and ALU instructions (tasks) + compute.io.i_post(0) := load.io.o_post + compute.io.i_post(1) := store.io.o_post + compute.io.inst <> fetch.io.inst.co + compute.io.uop_baddr := io.vcr.ptrs(1) + compute.io.acc_baddr := io.vcr.ptrs(4) + compute.io.inp <> load.io.inp + compute.io.wgt <> load.io.wgt + + // The store module performs the following: + // - Writes results from compute into scratchpads (SRAMs) + // - Store results from scratchpads (SRAMs) to memory (DRAM) + store.io.i_post := compute.io.o_post(1) + store.io.inst <> fetch.io.inst.st + store.io.out_baddr := io.vcr.ptrs(5) + store.io.out <> compute.io.out + + // Finish instruction is executed and asserts the VCR finish flag + val finish = RegNext(compute.io.finish) + io.vcr.finish := finish +} diff --git a/vta/hardware/chisel/src/main/scala/core/Decode.scala b/vta/hardware/chisel/src/main/scala/core/Decode.scala new file mode 100644 index 000000000000..f5bf3406347d --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Decode.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +import ISA._ + +/** MemDecode. + * + * Decode memory instructions with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. These are the instructions + * decoded with this bundle: + * - LUOP + * - LWGT + * - LINP + * - LACC + * - SOUT + */ +class MemDecode extends Bundle { + val xpad_1 = UInt(M_PAD_BITS.W) + val xpad_0 = UInt(M_PAD_BITS.W) + val ypad_1 = UInt(M_PAD_BITS.W) + val ypad_0 = UInt(M_PAD_BITS.W) + val xstride = UInt(M_STRIDE_BITS.W) + val xsize = UInt(M_SIZE_BITS.W) + val ysize = UInt(M_SIZE_BITS.W) + val empty_0 = UInt(7.W) // derive this + val dram_offset = UInt(M_DRAM_OFFSET_BITS.W) + val sram_offset = UInt(M_SRAM_OFFSET_BITS.W) + val id = UInt(M_ID_BITS.W) + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** GemmDecode. + * + * Decode GEMM instruction with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. + */ +class GemmDecode extends Bundle { + val wgt_1 = UInt(C_WIDX_BITS.W) + val wgt_0 = UInt(C_WIDX_BITS.W) + val inp_1 = UInt(C_IIDX_BITS.W) + val inp_0 = UInt(C_IIDX_BITS.W) + val acc_1 = UInt(C_AIDX_BITS.W) + val acc_0 = UInt(C_AIDX_BITS.W) + val empty_0 = Bool() + val lp_1 = UInt(C_ITER_BITS.W) + val lp_0 = UInt(C_ITER_BITS.W) + val uop_end = UInt(C_UOP_END_BITS.W) + val uop_begin = UInt(C_UOP_BGN_BITS.W) + val reset = Bool() + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** AluDecode. + * + * Decode ALU instructions with a Bundle. This is similar to an union, + * therefore order matters when declaring fields. These are the instructions + * decoded with this bundle: + * - VMIN + * - VMAX + * - VADD + * - VSHX + */ +class AluDecode extends Bundle { + val empty_1 = Bool() + val alu_imm = UInt(C_ALU_IMM_BITS.W) + val alu_use_imm = Bool() + val alu_op = UInt(C_ALU_DEC_BITS.W) + val src_1 = UInt(C_IIDX_BITS.W) + val src_0 = UInt(C_IIDX_BITS.W) + val dst_1 = UInt(C_AIDX_BITS.W) + val dst_0 = UInt(C_AIDX_BITS.W) + val empty_0 = Bool() + val lp_1 = UInt(C_ITER_BITS.W) + val lp_0 = UInt(C_ITER_BITS.W) + val uop_end = UInt(C_UOP_END_BITS.W) + val uop_begin = UInt(C_UOP_BGN_BITS.W) + val reset = Bool() + val push_next = Bool() + val push_prev = Bool() + val pop_next = Bool() + val pop_prev = Bool() + val op = UInt(OP_BITS.W) +} + +/** UopDecode. + * + * Decode micro-ops (uops). + */ +class UopDecode extends Bundle { + val u2 = UInt(10.W) + val u1 = UInt(11.W) + val u0 = UInt(11.W) +} + +/** FetchDecode. + * + * Partial decoding for dispatching instructions to Load, Compute, and Store. + */ +class FetchDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val isLoad = Output(Bool()) + val isCompute = Output(Bool()) + val isStore = Output(Bool()) + }) + val csignals = + ListLookup(io.inst, + List(N, OP_X), + Array( + LUOP -> List(Y, OP_G), + LWGT -> List(Y, OP_L), + LINP -> List(Y, OP_L), + LACC -> List(Y, OP_G), + SOUT -> List(Y, OP_S), + GEMM -> List(Y, OP_G), + FNSH -> List(Y, OP_G), + VMIN -> List(Y, OP_G), + VMAX -> List(Y, OP_G), + VADD -> List(Y, OP_G), + VSHX -> List(Y, OP_G) + ) + ) + + val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals + + io.isLoad := cs_val_inst & cs_op_type === OP_L + io.isCompute := cs_val_inst & cs_op_type === OP_G + io.isStore := cs_val_inst & cs_op_type === OP_S +} + +/** LoadDecode. + * + * Decode dependencies, type and sync for Load module. + */ +class LoadDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_next = Output(Bool()) + val pop_next = Output(Bool()) + val isInput = Output(Bool()) + val isWeight = Output(Bool()) + val isSync = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_next := dec.push_next + io.pop_next := dec.pop_next + io.isInput := io.inst === LINP & dec.xsize =/= 0.U + io.isWeight := io.inst === LWGT & dec.xsize =/= 0.U + io.isSync := (io.inst === LINP | io.inst === LWGT) & dec.xsize === 0.U +} + +/** ComputeDecode. + * + * Decode dependencies, type and sync for Compute module. + */ +class ComputeDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_next = Output(Bool()) + val push_prev = Output(Bool()) + val pop_next = Output(Bool()) + val pop_prev = Output(Bool()) + val isLoadAcc = Output(Bool()) + val isLoadUop = Output(Bool()) + val isSync = Output(Bool()) + val isAlu = Output(Bool()) + val isGemm = Output(Bool()) + val isFinish = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_next := dec.push_next + io.push_prev := dec.push_prev + io.pop_next := dec.pop_next + io.pop_prev := dec.pop_prev + io.isLoadAcc := io.inst === LACC & dec.xsize =/= 0.U + io.isLoadUop := io.inst === LUOP & dec.xsize =/= 0.U + io.isSync := (io.inst === LACC | io.inst === LUOP) & dec.xsize === 0.U + io.isAlu := io.inst === VMIN | io.inst === VMAX | io.inst === VADD | io.inst === VSHX + io.isGemm := io.inst === GEMM + io.isFinish := io.inst === FNSH +} + +/** StoreDecode. + * + * Decode dependencies, type and sync for Store module. + */ +class StoreDecode extends Module { + val io = IO(new Bundle { + val inst = Input(UInt(INST_BITS.W)) + val push_prev = Output(Bool()) + val pop_prev = Output(Bool()) + val isStore = Output(Bool()) + val isSync = Output(Bool()) + }) + val dec = io.inst.asTypeOf(new MemDecode) + io.push_prev := dec.push_prev + io.pop_prev := dec.pop_prev + io.isStore := io.inst === SOUT & dec.xsize =/= 0.U + io.isSync := io.inst === SOUT & dec.xsize === 0.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/Fetch.scala b/vta/hardware/chisel/src/main/scala/core/Fetch.scala new file mode 100644 index 000000000000..bcc164a8f623 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Fetch.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Fetch. + * + * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the + * VTA Memory Engine (VME), and push them into an instruction queue called + * inst_q. Once the instruction queue is full, instructions are dispatched to + * the Load, Compute and Store module queues based on the instruction opcode. + * After draining the queue, the fetch unit checks if there are more instructions + * via the ins_count register which is written by the host. + * + * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB) + * because we are using a DRAM payload of 8-bytes or half of a VTA instruction. + * This should be configurable for larger payloads, i.e. 64-bytes, which can load + * more than one instruction at the time. Finally, the instruction queue is + * sized (entries_q), depending on the maximum burst allowed in the memory. + */ +class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val launch = Input(Bool()) + val ins_baddr = Input(UInt(mp.addrBits.W)) + val ins_count = Input(UInt(vp.regBits.W)) + val vme_rd = new VMEReadMaster + val inst = new Bundle { + val ld = Decoupled(UInt(INST_BITS.W)) + val co = Decoupled(UInt(INST_BITS.W)) + val st = Decoupled(UInt(INST_BITS.W)) + } + }) + val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word + val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q)) + val dec = Module(new FetchDecode) + + val s1_launch = RegNext(io.launch) + val pulse = io.launch & ~s1_launch + + val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr)) + val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + + val xrem = Reg(chiselTypeOf(io.ins_count)) + val xsize = (io.ins_count << 1.U) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + + val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (pulse) { + state := sReadCmd + when (xsize < xmax) { + rlen := xsize + ilen := xsize >> 1.U + xrem := 0.U + } .otherwise { + rlen := xmax - 1.U + ilen := (xmax >> 1.U) - 1.U + xrem := xsize - xmax + } + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadLSB + } + } + is (sReadLSB) { + when (io.vme_rd.data.valid) { + state := sReadMSB + } + } + is (sReadMSB) { + when (io.vme_rd.data.valid) { + when (inst_q.io.count === ilen) { + state := sDrain + } .otherwise { + state := sReadLSB + } + } + } + is (sDrain) { + when (inst_q.io.count === 0.U) { + when (xrem === 0.U) { + state := sIdle + } .elsewhen (xrem < xmax) { + state := sReadCmd + rlen := xrem + ilen := xrem >> 1.U + xrem := 0.U + } .otherwise { + state := sReadCmd + rlen := xmax - 1.U + ilen := (xmax >> 1.U) - 1.U + xrem := xrem - xmax + } + } + } + } + + // read instructions from dram + when (state === sIdle) { + raddr := io.ins_baddr + } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) { + raddr := raddr + xmax_bytes + } + + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := raddr + io.vme_rd.cmd.bits.len := rlen + + io.vme_rd.data.ready := inst_q.io.enq.ready + + val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits)) + val msb = io.vme_rd.data.bits + val inst = Cat(msb, lsb) + + when (state === sReadLSB) { lsb := io.vme_rd.data.bits } + + inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB + inst_q.io.enq.bits := inst + + // decode + dec.io.inst := inst_q.io.deq.bits + + // instruction queues + io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain + io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain + io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain + + io.inst.ld.bits := inst_q.io.deq.bits + io.inst.co.bits := inst_q.io.deq.bits + io.inst.st.bits := inst_q.io.deq.bits + + // check if selected queue is ready + val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt + val deq_ready = + MuxLookup(deq_sel, + false.B, // default + Array( + "h_01".U -> io.inst.ld.ready, + "h_02".U -> io.inst.st.ready, + "h_04".U -> io.inst.co.ready + ) + ) + + // dequeue instruction + inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain + + + // debug + if (debug) { + when (state === sIdle && pulse) { + printf("[Fetch] Launch\n") + } + // instruction + when (inst_q.io.deq.fire()) { + when (dec.io.isLoad) { + printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits) + } + when (dec.io.isCompute) { + printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits) + } + when (dec.io.isStore) { + printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits) + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/ISA.scala b/vta/hardware/chisel/src/main/scala/core/ISA.scala new file mode 100644 index 000000000000..c3bf6097adcd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/ISA.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +/** ISAConstants. + * + * These constants are used for decoding (parsing) fields on instructions. + */ +trait ISAConstants +{ + val INST_BITS = 128 + + val OP_BITS = 3 + + val M_DEP_BITS = 4 + val M_ID_BITS = 2 + val M_SRAM_OFFSET_BITS = 16 + val M_DRAM_OFFSET_BITS = 32 + val M_SIZE_BITS = 16 + val M_STRIDE_BITS = 16 + val M_PAD_BITS = 4 + + val C_UOP_BGN_BITS = 13 + val C_UOP_END_BITS = 14 + val C_ITER_BITS = 14 + val C_AIDX_BITS = 11 + val C_IIDX_BITS = 11 + val C_WIDX_BITS = 10 + val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction + val C_ALU_OP_BITS = 3 + val C_ALU_IMM_BITS = 16 + + val Y = true.B + val N = false.B + + val OP_L = 0.asUInt(OP_BITS.W) + val OP_S = 1.asUInt(OP_BITS.W) + val OP_G = 2.asUInt(OP_BITS.W) + val OP_F = 3.asUInt(OP_BITS.W) + val OP_A = 4.asUInt(OP_BITS.W) + val OP_X = 5.asUInt(OP_BITS.W) + + val ALU_OP_NUM = 5 + val ALU_OP = Enum(ALU_OP_NUM) + + val M_ID_U = 0.asUInt(M_ID_BITS.W) + val M_ID_W = 1.asUInt(M_ID_BITS.W) + val M_ID_I = 2.asUInt(M_ID_BITS.W) + val M_ID_A = 3.asUInt(M_ID_BITS.W) +} + +/** ISA. + * + * This is the VTA ISA, here we specify the cares and dont-cares that makes + * decoding easier. Since instructions are quite long 128-bit, we could generate + * these based on ISAConstants. + * + * FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler + * TODO: Add VXOR to clear accumulator + */ +object ISA { + def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000") + def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000") + def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000") + def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000") + def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001") + def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010") + def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011") +} diff --git a/vta/hardware/chisel/src/main/scala/core/Load.scala b/vta/hardware/chisel/src/main/scala/core/Load.scala new file mode 100644 index 000000000000..64795139aa4e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Load.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Load. + * + * Load inputs and weights from memory (DRAM) into scratchpads (SRAMs). + * This module instantiate the TensorLoad unit which is in charge of + * loading 1D and 2D tensors to scratchpads, so it can be used by + * other modules such as Compute. + */ +class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Input(Bool()) + val o_post = Output(Bool()) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val inp_baddr = Input(UInt(mp.addrBits.W)) + val wgt_baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = Vec(2, new VMEReadMaster) + val inp = new TensorClient(tensorType = "inp") + val wgt = new TensorClient(tensorType = "wgt") + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0)) + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + val dec = Module(new LoadDecode) + dec.io.inst := inst_q.io.deq.bits + + val tensorType = Seq("inp", "wgt") + val tensorDec = Seq(dec.io.isInput, dec.io.isWeight) + val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i)))) + + val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B) + val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done) + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (dec.io.isInput || dec.io.isWeight) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // load tensor + // [0] input (inp) + // [1] weight (wgt) + val ptr = Seq(io.inp_baddr, io.wgt_baddr) + val tsor = Seq(io.inp, io.wgt) + for (i <- 0 until 2) { + tensorLoad(i).io.start := state === sIdle & start & tensorDec(i) + tensorLoad(i).io.inst := inst_q.io.deq.bits + tensorLoad(i).io.baddr := ptr(i) + tensorLoad(i).io.tensor <> tsor(i) + io.vme_rd(i) <> tensorLoad(i).io.vme_rd + } + + // semaphore + s.io.spost := io.i_post + s.io.swait := dec.io.pop_next & (state === sIdle & start) + io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync)) + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Load] start sync\n") + } .elsewhen (dec.io.isInput) { + printf("[Load] start input\n") + } .elsewhen (dec.io.isWeight) { + printf("[Load] start weight\n") + } + } + // done + when (state === sSync) { + printf("[Load] done sync\n") + } + when (state === sExe) { + when (done) { + when (dec.io.isInput) { + printf("[Load] done input\n") + } .elsewhen (dec.io.isWeight) { + printf("[Load] done weight\n") + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/LoadUop.scala b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala new file mode 100644 index 000000000000..07296523b254 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** UopMaster. + * + * Uop interface used by a master module, i.e. TensorAlu or TensorGemm, + * to request a micro-op (uop) from the uop-scratchpad. The index (idx) is + * used as an address to find the uop in the uop-scratchpad. + */ +class UopMaster(implicit p: Parameters) extends Bundle { + val addrBits = log2Ceil(p(CoreKey).uopMemDepth) + val idx = ValidIO(UInt(addrBits.W)) + val data = Flipped(ValidIO(new UopDecode)) + override def cloneType = new UopMaster().asInstanceOf[this.type] +} + +/** UopClient. + * + * Uop interface used by a client module, i.e. LoadUop, to receive + * a request from a master module, i.e. TensorAlu or TensorGemm. + * The index (idx) is used as an address to find the uop in the uop-scratchpad. + */ +class UopClient(implicit p: Parameters) extends Bundle { + val addrBits = log2Ceil(p(CoreKey).uopMemDepth) + val idx = Flipped(ValidIO(UInt(addrBits.W))) + val data = ValidIO(new UopDecode) + override def cloneType = new UopClient().asInstanceOf[this.type] +} + +/** LoadUop. + * + * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the + * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in + * group of 2 given the fact that the DRAM payload is 8-bytes. This module + * should be modified later on to support different DRAM sizes efficiently. + */ +class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = new VMEReadMaster + val uop = new UopClient + }) + val numUop = 2 // store two uops per sram word + val uopBits = p(CoreKey).uopBits + val uopDepth = p(CoreKey).uopMemDepth / numUop + + val dec = io.inst.asTypeOf(new MemDecode) + val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr)) + val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + + val offsetIsEven = (dec.sram_offset % 2.U) === 0.U + val sizeIsEven = (dec.xsize % 2.U) === 0.U + + val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.vme_rd.data.valid) { + when(xcnt === xlen) { + when (xrem === 0.U) { + state := sIdle + } .elsewhen (xrem < xmax) { + state := sReadCmd + xlen := xrem + xrem := 0.U + } .otherwise { + state := sReadCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } + } + } + } + } + + // read-from-dram + when (state === sIdle) { + when (offsetIsEven) { + raddr := io.baddr + dec.dram_offset + } .otherwise { + raddr := io.baddr + dec.dram_offset - 4.U + } + } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) { + raddr := raddr + xmax_bytes + } + + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := raddr + io.vme_rd.cmd.bits.len := xlen + + io.vme_rd.data.ready := state === sReadData + + when (state =/= sReadData) { + xcnt := 0.U + } .elsewhen (io.vme_rd.data.fire()) { + xcnt := xcnt + 1.U + } + + val waddr = Reg(UInt(log2Ceil(uopDepth).W)) + when (state === sIdle) { + waddr := dec.sram_offset >> log2Ceil(numUop) + } .elsewhen (io.vme_rd.data.fire()) { + waddr := waddr + 1.U + } + + val wdata = Wire(Vec(numUop, UInt(uopBits.W))) + val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata)) + val wmask = Reg(Vec(numUop, Bool())) + + when (offsetIsEven) { + when (sizeIsEven) { + wmask := "b_11".U.asTypeOf(wmask) + } .elsewhen (io.vme_rd.cmd.fire()) { + when (dec.xsize === 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } .elsewhen (io.vme_rd.data.fire()) { + when (xcnt === xlen - 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } + } .otherwise { + when (io.vme_rd.cmd.fire()) { + wmask := "b_10".U.asTypeOf(wmask) + } .elsewhen (io.vme_rd.data.fire()) { + when (sizeIsEven && xcnt === xlen - 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + } .otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } + } + } + + wdata := io.vme_rd.data.bits.asTypeOf(wdata) + when (io.vme_rd.data.fire()) { + mem.write(waddr, wdata, wmask) + } + + // read-from-sram + io.uop.data.valid := RegNext(io.uop.idx.valid) + + val sIdx = io.uop.idx.bits % numUop.U + val rIdx = io.uop.idx.bits >> log2Ceil(numUop) + val memRead = mem.read(rIdx, io.uop.idx.valid) + val sWord = memRead.asUInt.asTypeOf(wdata) + val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits) + + io.uop.data.bits <> sUop + + // done + io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U + + // debug + if (debug) { + when (io.vme_rd.cmd.fire()) { + printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem) + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/Semaphore.scala b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala new file mode 100644 index 000000000000..06df51e20e27 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ + +/** Semaphore. + * + * This semaphore is used instead of push/pop fifo, used in the initial + * version of VTA. This semaphore is incremented (spost) or decremented (swait) + * depending on the push and pop fields on instructions to prevent RAW and WAR + * hazards. + */ +class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module { + val io = IO(new Bundle { + val spost = Input(Bool()) + val swait = Input(Bool()) + val sready = Output(Bool()) + }) + val cnt = RegInit(counterInitValue.U(counterBits.W)) + when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U } + when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U } + io.sready := cnt =/= 0.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/Store.scala b/vta/hardware/chisel/src/main/scala/core/Store.scala new file mode 100644 index 000000000000..5d89871e65be --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/Store.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** Store. + * + * Store results back to memory (DRAM) from scratchpads (SRAMs). + * This module instantiate the TensorStore unit which is in charge + * of storing 1D and 2D tensors to main memory. + */ +class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val i_post = Input(Bool()) + val o_post = Output(Bool()) + val inst = Flipped(Decoupled(UInt(INST_BITS.W))) + val out_baddr = Input(UInt(mp.addrBits.W)) + val vme_wr = new VMEWriteMaster + val out = new TensorClient(tensorType = "out") + }) + val sIdle :: sSync :: sExe :: Nil = Enum(3) + val state = RegInit(sIdle) + + val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0)) + val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries)) + + val dec = Module(new StoreDecode) + dec.io.inst := inst_q.io.deq.bits + + val tensorStore = Module(new TensorStore(tensorType = "out")) + + val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B) + val done = tensorStore.io.done + + // control + switch (state) { + is (sIdle) { + when (start) { + when (dec.io.isSync) { + state := sSync + } .elsewhen (dec.io.isStore) { + state := sExe + } + } + } + is (sSync) { + state := sIdle + } + is (sExe) { + when (done) { + state := sIdle + } + } + } + + // instructions + inst_q.io.enq <> io.inst + inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) + + // store + tensorStore.io.start := state === sIdle & start & dec.io.isStore + tensorStore.io.inst := inst_q.io.deq.bits + tensorStore.io.baddr := io.out_baddr + io.vme_wr <> tensorStore.io.vme_wr + tensorStore.io.tensor <> io.out + + // semaphore + s.io.spost := io.i_post + s.io.swait := dec.io.pop_prev & (state === sIdle & start) + io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync)) + + // debug + if (debug) { + // start + when (state === sIdle && start) { + when (dec.io.isSync) { + printf("[Store] start sync\n") + } .elsewhen (dec.io.isStore) { + printf("[Store] start\n") + } + } + // done + when (state === sSync) { + printf("[Store] done sync\n") + } + when (state === sExe) { + when (done) { + printf("[Store] done\n") + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala new file mode 100644 index 000000000000..7f429be7249f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ + +/** ALU datapath */ +class Alu(implicit p: Parameters) extends Module { + val aluBits = p(CoreKey).accBits + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val a = Input(SInt(aluBits.W)) + val b = Input(SInt(aluBits.W)) + val y = Output(SInt(aluBits.W)) + }) + + // FIXME: the following three will change once we support properly SHR and SHL + val ub = io.b.asUInt + val width = log2Ceil(aluBits) + val m = ~ub(width - 1, 0) + 1.U + + val n = ub(width - 1, 0) + val fop = Seq(Mux(io.a < io.b, io.a, io.b), + Mux(io.a < io.b, io.b, io.a), + io.a + io.b, + io.a >> n, + io.a << m) + + val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i)) + io.y := MuxLookup(io.opcode, io.a, opmux) +} + +/** Pipelined ALU */ +class AluReg(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W))) + val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W))) + val y = ValidIO(UInt(p(CoreKey).accBits.W)) + }) + val alu = Module(new Alu) + val rA = RegEnable(io.a.bits, io.a.valid) + val rB = RegEnable(io.b.bits, io.b.valid) + val valid = RegNext(io.b.valid) + + alu.io.opcode := io.opcode + + // register input + alu.io.a := rA.asSInt + alu.io.b := rB.asSInt + + // output + io.y.valid := valid + io.y.bits := alu.io.y.asUInt +} + +/** Vector of pipeline ALUs */ +class AluVector(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val opcode = Input(UInt(C_ALU_OP_BITS.W)) + val acc_a = new TensorMasterData(tensorType = "acc") + val acc_b = new TensorMasterData(tensorType = "acc") + val acc_y = new TensorClientData(tensorType = "acc") + val out = new TensorClientData(tensorType = "out") + }) + val blockOut = p(CoreKey).blockOut + val f = Seq.fill(blockOut)(Module(new AluReg)) + val valid = Wire(Vec(blockOut, Bool())) + for (i <- 0 until blockOut) { + f(i).io.opcode := io.opcode + f(i).io.a.valid := io.acc_a.data.valid + f(i).io.a.bits := io.acc_a.data.bits(0)(i) + f(i).io.b.valid := io.acc_b.data.valid + f(i).io.b.bits := io.acc_b.data.bits(0)(i) + valid(i) := f(i).io.y.valid + io.acc_y.data.bits(0)(i) := f(i).io.y.bits + io.out.data.bits(0)(i) := f(i).io.y.bits + } + io.acc_y.data.valid := valid.asUInt.andR + io.out.data.valid := valid.asUInt.andR +} + +/** TensorAlu. + * + * This unit instantiate the ALU vector unit (AluVector) and go over the + * micro-ops (uops) which are used to read the source operands (vectors) + * from the acc-scratchpad and then they are written back the same + * acc-scratchpad. + */ +class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val uop = new UopMaster + val acc = new TensorMaster(tensorType = "acc") + val out = new TensorMaster(tensorType = "out") + }) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6) + val state = RegInit(sIdle) + val alu = Module(new AluVector) + val dec = io.inst.asTypeOf(new AluDecode) + val uop_idx = Reg(chiselTypeOf(dec.uop_end)) + val uop_end = dec.uop_end + val uop_dst = Reg(chiselTypeOf(dec.uop_end)) + val uop_src = Reg(chiselTypeOf(dec.uop_end)) + val cnt_o = Reg(chiselTypeOf(dec.lp_0)) + val dst_o = Reg(chiselTypeOf(dec.uop_end)) + val src_o = Reg(chiselTypeOf(dec.uop_end)) + val cnt_i = Reg(chiselTypeOf(dec.lp_1)) + val dst_i = Reg(chiselTypeOf(dec.uop_end)) + val src_i = Reg(chiselTypeOf(dec.uop_end)) + val done = + state === sExe & + alu.io.out.data.valid & + (cnt_o === dec.lp_0 - 1.U) & + (cnt_i === dec.lp_1 - 1.U) & + (uop_idx === uop_end - 1.U) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadUop + } + } + is (sReadUop) { + state := sComputeIdx + } + is (sComputeIdx) { + state := sReadTensorA + } + is (sReadTensorA) { + state := sReadTensorB + } + is (sReadTensorB) { + state := sExe + } + is (sExe) { + when (alu.io.out.data.valid) { + when ((cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { + state := sIdle + } .otherwise { + state := sReadUop + } + } + } + } + + when (state === sIdle || + (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U)) { + uop_idx := dec.uop_begin + } .elsewhen (state === sExe && alu.io.out.data.valid) { + uop_idx := uop_idx + 1.U + } + + when (state === sIdle) { + cnt_o := 0.U + dst_o := 0.U + src_o := 0.U + } .elsewhen (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { + cnt_o := cnt_o + 1.U + dst_o := dst_o + dec.dst_0 + src_o := src_o + dec.src_0 + } + + when (state === sIdle) { + cnt_i := 0.U + dst_i := 0.U + src_i := 0.U + } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + dst_i := dst_o + src_i := src_o + } .elsewhen (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + dst_i := dst_i + dec.dst_1 + src_i := src_i + dec.src_1 + } + + when (state === sComputeIdx && io.uop.data.valid) { + uop_dst := io.uop.data.bits.u0 + dst_i + uop_src := io.uop.data.bits.u1 + src_i + } + + // uop + io.uop.idx.valid := state === sReadUop + io.uop.idx.bits := uop_idx + + // acc_i + io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm) + io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src) + + // imm + val tensorImm = Wire(new TensorClientData(tensorType = "acc")) + tensorImm.data.valid := state === sReadTensorB + tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } } + + // alu + val isSHR = dec.alu_op === ALU_OP(3) + val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1) + val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op)) + alu.io.opcode := fixme_alu_op + alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB + alu.io.acc_a.data.bits <> io.acc.rd.data.bits + alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe) + alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits) + + // acc_o + io.acc.wr.valid := alu.io.acc_y.data.valid + io.acc.wr.bits.idx := uop_dst + io.acc.wr.bits.data <> alu.io.acc_y.data.bits + + // out + io.out.wr.valid := alu.io.out.data.valid + io.out.wr.bits.idx := uop_dst + io.out.wr.bits.data <> alu.io.out.data.bits + io.out.tieoffRead() // write-only + + io.done := done + + if (debug) { + + when (state === sReadUop) { + printf("[TensorAlu] [uop] idx:%x\n", uop_idx) + } + + when (state === sReadTensorA) { + printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src) + } + + when (state === sIdle && io.start) { + printf(p"[TensorAlu] decode:$dec\n") + } + + alu.io.acc_a.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_a.data.valid) { + printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.acc_b.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_b.data.valid) { + printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.acc_y.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.acc_y.data.valid) { + printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem) + } + } + } + + alu.io.out.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (alu.io.out.data.valid) { + printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem) + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala new file mode 100644 index 000000000000..2dd8c33aea33 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import vta.util.config._ +import scala.math.pow + +/** Pipelined multiply and accumulate */ +class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module { + require (cBits >= dataBits * 2) + require (outBits >= dataBits * 2) + val io = IO(new Bundle { + val a = Input(SInt(dataBits.W)) + val b = Input(SInt(dataBits.W)) + val c = Input(SInt(cBits.W)) + val y = Output(SInt(outBits.W)) + }) + val mult = Wire(SInt(cBits.W)) + val add = Wire(SInt(outBits.W)) + val rA = RegNext(io.a) + val rB = RegNext(io.b) + val rC = RegNext(io.c) + mult := rA * rB + add := rC + mult + io.y := add +} + +/** Pipelined adder */ +class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module { + require (outBits >= dataBits) + val io = IO(new Bundle { + val a = Input(SInt(dataBits.W)) + val b = Input(SInt(dataBits.W)) + val y = Output(SInt(outBits.W)) + }) + val add = Wire(SInt(outBits.W)) + val rA = RegNext(io.a) + val rB = RegNext(io.b) + add := rA + rB + io.y := add +} + +/** Pipelined DotProduct based on MAC and Adder */ +class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module { + val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" + require(size >= 4 && isPow2(size), errMsg) + val b = dataBits * 2 + val outBits = b + log2Ceil(size) + 1 + val io = IO(new Bundle { + val a = Input(Vec(size, SInt(dataBits.W))) + val b = Input(Vec(size, SInt(dataBits.W))) + val y = Output(SInt(outBits.W)) + }) + val p = log2Ceil(size/2) + val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt) + val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i))) + val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i))) + val m = Seq.tabulate(2)(i => + Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1))) + ) + val a = Seq.tabulate(p)(i => + Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3))) + ) + + for (i <- 0 until log2Ceil(size)) { + for (j <- 0 until s(i)) { + if (i == 0) { + m(i)(j).io.a := io.a(j) + m(i)(j).io.b := io.b(j) + m(i)(j).io.c := 0.S + m(i + 1)(j).io.a := da(j) + m(i + 1)(j).io.b := db(j) + m(i + 1)(j).io.c := m(i)(j).io.y + } else if (i == 1) { + a(i - 1)(j).io.a := m(i)(2*j).io.y + a(i - 1)(j).io.b := m(i)(2*j + 1).io.y + } else { + a(i - 1)(j).io.a := a(i - 2)(2*j).io.y + a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y + } + } + } + io.y := a(p-1)(0).io.y +} + +/** Perform matric-vector-multiplication based on DotProduct */ +class MatrixVectorCore(implicit p: Parameters) extends Module { + val accBits = p(CoreKey).accBits + val size = p(CoreKey).blockOut + val dataBits = p(CoreKey).inpBits + val io = IO(new Bundle{ + val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr + val inp = new TensorMasterData(tensorType = "inp") + val wgt = new TensorMasterData(tensorType = "wgt") + val acc_i = new TensorMasterData(tensorType = "acc") + val acc_o = new TensorClientData(tensorType = "acc") + val out = new TensorClientData(tensorType = "out") + }) + val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size))) + val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) + val add = Seq.fill(size)(Wire(SInt(accBits.W))) + val vld = Wire(Vec(size, Bool())) + + for (i <- 0 until size) { + acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset + acc(i).io.enq.bits := io.acc_i.data.bits(0)(i) + for (j <- 0 until size) { + dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt + dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt + } + add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y + io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt) + io.out.data.bits(0)(i) := add(i).asUInt + vld(i) := acc(i).io.deq.valid + } + io.acc_o.data.valid := vld.asUInt.andR | io.reset + io.out.data.valid := vld.asUInt.andR +} + +/** TensorGemm. + * + * This unit instantiate the MatrixVectorCore and go over the + * micro-ops (uops) which are used to read inputs, weights and biases, + * and writes results back to the acc and out scratchpads. + * + * Also, the TensorGemm uses the reset field in the Gemm instruction to + * clear or zero-out the acc-scratchpad locations based on the micro-ops. + */ +class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val uop = new UopMaster + val inp = new TensorMaster(tensorType = "inp") + val wgt = new TensorMaster(tensorType = "wgt") + val acc = new TensorMaster(tensorType = "acc") + val out = new TensorMaster(tensorType = "out") + }) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6) + val state = RegInit(sIdle) + val mvc = Module(new MatrixVectorCore) + val dec = io.inst.asTypeOf(new GemmDecode) + val uop_idx = Reg(chiselTypeOf(dec.uop_end)) + val uop_end = dec.uop_end + val uop_acc = Reg(chiselTypeOf(dec.uop_end)) + val uop_inp = Reg(chiselTypeOf(dec.uop_end)) + val uop_wgt = Reg(chiselTypeOf(dec.uop_end)) + val cnt_o = Reg(chiselTypeOf(dec.lp_0)) + val acc_o = Reg(chiselTypeOf(dec.uop_end)) + val inp_o = Reg(chiselTypeOf(dec.uop_end)) + val wgt_o = Reg(chiselTypeOf(dec.uop_end)) + val cnt_i = Reg(chiselTypeOf(dec.lp_1)) + val acc_i = Reg(chiselTypeOf(dec.uop_end)) + val inp_i = Reg(chiselTypeOf(dec.uop_end)) + val wgt_i = Reg(chiselTypeOf(dec.uop_end)) + val pBits = log2Ceil(p(CoreKey).blockOut) + 1 + val inflight = Reg(UInt(pBits.W)) + val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits)) + val done = inflight === 0.U & + ((state === sExe & + cnt_o === dec.lp_0 - 1.U & + cnt_i === dec.lp_1 - 1.U & + uop_idx === uop_end - 1.U & + inflight === 0.U) | + state === sWait) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sReadUop + } + } + is (sReadUop) { + state := sComputeIdx + } + is (sComputeIdx) { + state := sReadTensor + } + is (sReadTensor) { + state := sExe + } + is (sExe) { + when ((cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { + when (inflight =/= 0.U) { + state := sWait + } .otherwise { + state := sIdle + } + } .otherwise { + state := sReadUop + } + } + is (sWait) { + when (inflight === 0.U) { + state := sIdle + } + } + } + + when (state === sIdle) { + inflight := 0.U + } .elsewhen (!dec.reset) { + when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check + inflight := inflight + 1.U + } .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check + inflight := inflight - 1.U + } + } + + when (state === sIdle || + (state === sExe && + uop_idx === uop_end - 1.U)) { + uop_idx := dec.uop_begin + } .elsewhen (state === sExe) { + uop_idx := uop_idx + 1.U + } + + when (state === sIdle) { + cnt_o := 0.U + acc_o := 0.U + inp_o := 0.U + wgt_o := 0.U + } .elsewhen (state === sExe && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { + cnt_o := cnt_o + 1.U + acc_o := acc_o + dec.acc_0 + inp_o := inp_o + dec.inp_0 + wgt_o := wgt_o + dec.wgt_0 + } + + when (state === sIdle) { + cnt_i := 0.U + acc_i := 0.U + inp_i := 0.U + wgt_i := 0.U + } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + acc_i := acc_o + inp_i := inp_o + wgt_i := wgt_o + } .elsewhen (state === sExe && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + acc_i := acc_i + dec.acc_1 + inp_i := inp_i + dec.inp_1 + wgt_i := wgt_i + dec.wgt_1 + } + + when (state === sComputeIdx && io.uop.data.valid) { + uop_acc := io.uop.data.bits.u0 + acc_i + uop_inp := io.uop.data.bits.u1 + inp_i + uop_wgt := io.uop.data.bits.u2 + wgt_i + } + + wrpipe.io.enq.valid := state === sExe & ~dec.reset + wrpipe.io.enq.bits := uop_acc + + // uop + io.uop.idx.valid := state === sReadUop + io.uop.idx.bits := uop_idx + + // inp + io.inp.rd.idx.valid := state === sReadTensor + io.inp.rd.idx.bits := uop_inp + io.inp.tieoffWrite() // read-only + + // wgt + io.wgt.rd.idx.valid := state === sReadTensor + io.wgt.rd.idx.bits := uop_wgt + io.wgt.tieoffWrite() // read-only + + // acc_i + io.acc.rd.idx.valid := state === sReadTensor + io.acc.rd.idx.bits := uop_acc + + // mvc + mvc.io.reset := dec.reset & state === sExe + mvc.io.inp.data <> io.inp.rd.data + mvc.io.wgt.data <> io.wgt.rd.data + mvc.io.acc_i.data <> io.acc.rd.data + + // acc_o + io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid) + io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits) + io.acc.wr.bits.data <> mvc.io.acc_o.data.bits + + // out + io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid + io.out.wr.bits.idx := wrpipe.io.deq.bits + io.out.wr.bits.data <> mvc.io.out.data.bits + io.out.tieoffRead() // write-only + + io.done := done + + if (debug) { + when (state === sReadUop && ~dec.reset) { + printf("[TensorGemm] [uop] idx:%x\n", uop_idx) + } + + when (state === sReadTensor && ~dec.reset) { + printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt) + } + + io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) => + when (io.inp.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt) + } + } + + io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) => + when (io.wgt.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt) + } + } + + io.acc.rd.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (io.acc.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem) + } + } + } + + mvc.io.acc_o.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (mvc.io.acc_o.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem) + } + } + } + + mvc.io.out.data.bits.foreach { tensor => + tensor.zipWithIndex.foreach { case(elem, i) => + when (mvc.io.out.data.valid && ~dec.reset) { + printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem) + } + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala new file mode 100644 index 000000000000..d96a681e7d69 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorStore. + * + * Load 1D and 2D tensors from main memory (DRAM) to input/weight + * scratchpads (SRAM). Also, there is support for zero padding, while + * doing the load. Zero-padding works on the y and x axis, and it is + * managed by TensorPadCtrl. The TensorDataCtrl is in charge of + * handling the way tensors are stored on the scratchpads. + */ +class TensorLoad(tensorType: String = "none", debug: Boolean = false) + (implicit p: Parameters) extends Module { + val tp = new TensorParams(tensorType) + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_rd = new VMEReadMaster + val tensor = new TensorClient(tensorType) + }) + val sizeFactor = tp.tensorLength * tp.numMemBlock + val strideFactor = tp.tensorLength * tp.tensorWidth + + val dec = io.inst.asTypeOf(new MemDecode) + val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor)) + val dataCtrlDone = RegInit(false.B) + val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor)) + val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor)) + val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor)) + val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor)) + + val tag = Reg(UInt(8.W)) + val set = Reg(UInt(8.W)) + + val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + when (dec.ypad_0 =/= 0.U) { + state := sYPad0 + } .elsewhen (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + is (sYPad0) { + when (yPadCtrl0.io.done) { + when (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + is (sXPad0) { + when (xPadCtrl0.io.done) { + state := sReadCmd + } + } + is (sReadCmd) { + when (io.vme_rd.cmd.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.vme_rd.data.valid) { + when (dataCtrl.io.done) { + when (dec.xpad_1 =/= 0.U) { + state := sXPad1 + } .elsewhen (dec.ypad_1 =/= 0.U) { + state := sYPad1 + } .otherwise { + state := sIdle + } + } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) { + when (dec.xpad_1 =/= 0.U) { + state := sXPad1 + } .elsewhen (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + } + is (sXPad1) { + when (xPadCtrl1.io.done) { + when (dataCtrlDone) { + when (dec.ypad_1 =/= 0.U) { + state := sYPad1 + } .otherwise { + state := sIdle + } + } .otherwise { + when (dec.xpad_0 =/= 0.U) { + state := sXPad0 + } .otherwise { + state := sReadCmd + } + } + } + } + is (sYPad1) { + when (yPadCtrl1.io.done && dataCtrlDone) { + state := sIdle + } + } + } + + // data controller + dataCtrl.io.start := state === sIdle & io.start + dataCtrl.io.inst := io.inst + dataCtrl.io.baddr := io.baddr + dataCtrl.io.xinit := io.vme_rd.cmd.fire() + dataCtrl.io.xupdate := io.vme_rd.data.fire() + dataCtrl.io.yupdate := io.vme_rd.data.fire() + + when (state === sIdle) { + dataCtrlDone := false.B + } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) { + dataCtrlDone := true.B + } + + // pad + yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start + + yPadCtrl1.io.start := dec.ypad_1 =/= 0.U & + ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone)) + + xPadCtrl0.io.start := dec.xpad_0 =/= 0.U & + ((state === sIdle & io.start) | + (state === sYPad0 & yPadCtrl0.io.done) | + (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone)) + + xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() & + ((dataCtrl.io.done) | + (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U)) + + yPadCtrl0.io.inst := io.inst + yPadCtrl1.io.inst := io.inst + xPadCtrl0.io.inst := io.inst + xPadCtrl1.io.inst := io.inst + + // read-from-dram + io.vme_rd.cmd.valid := state === sReadCmd + io.vme_rd.cmd.bits.addr := dataCtrl.io.addr + io.vme_rd.cmd.bits.len := dataCtrl.io.len + + io.vme_rd.data.ready := state === sReadData + + // write-to-sram + val isZeroPad = state === sYPad0 | + state === sXPad0 | + state === sXPad1 | + state === sYPad1 + + when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) { + tag := 0.U + } .elsewhen (io.vme_rd.data.fire() || isZeroPad) { + tag := tag + 1.U + } + + when (state === sIdle || state === sReadCmd || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) { + set := 0.U + } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) { + set := set + 1.U + } + + val waddr_cur = Reg(UInt(tp.memAddrBits.W)) + val waddr_nxt = Reg(UInt(tp.memAddrBits.W)) + when (state === sIdle) { + waddr_cur := dec.sram_offset + waddr_nxt := dec.sram_offset + } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) { + waddr_cur := waddr_cur + 1.U + } .elsewhen (dataCtrl.io.stride) { + waddr_cur := waddr_nxt + dec.xsize + waddr_nxt := waddr_nxt + dec.xsize + } + + val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) } + val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val no_mask = Wire(Vec(tp.numMemBlock, Bool())) + no_mask.foreach { m => m := true.B } + + for (i <- 0 until tp.tensorLength) { + for (j <- 0 until tp.numMemBlock) { + wmask(i)(j) := tag === j.U + wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits) + } + val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i)) + val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U) + val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur) + val muxWdata = Mux(state === sIdle, tdata, wdata(i)) + val muxWmask = Mux(state === sIdle, no_mask, wmask(i)) + when (muxWen) { + tensorFile(i).write(muxWaddr, muxWdata, muxWmask) + } + } + + // read-from-sram + val rvalid = RegNext(io.tensor.rd.idx.valid) + io.tensor.rd.data.valid := rvalid + + val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid)) + rdata.zipWithIndex.foreach { case(r, i) => + io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i)) + } + + // done + val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U + val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U + val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done + io.done := done_no_pad | done_x_pad | done_y_pad + + // debug + if (debug) { + if (tensorType == "inp") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + when (state === sYPad0) { + printf("[TensorLoad] [inp] sYPad0\n") + } + when (state === sYPad1) { + printf("[TensorLoad] [inp] sYPad1\n") + } + when (state === sXPad0) { + printf("[TensorLoad] [inp] sXPad0\n") + } + when (state === sXPad1) { + printf("[TensorLoad] [inp] sXPad1\n") + } + } else if (tensorType == "wgt") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + } else if (tensorType == "acc") { + when (io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + } + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala new file mode 100644 index 000000000000..0012e4771c0e --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorStore. + * + * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). + */ +class TensorStore(tensorType: String = "true", debug: Boolean = false) + (implicit p: Parameters) extends Module { + val tp = new TensorParams(tensorType) + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val vme_wr = new VMEWriteMaster + val tensor = new TensorClient(tensorType) + }) + val tensorLength = tp.tensorLength + val tensorWidth = tp.tensorWidth + val tensorElemBits = tp.tensorElemBits + val memBlockBits = tp.memBlockBits + val memDepth = tp.memDepth + val numMemBlock = tp.numMemBlock + + val dec = io.inst.asTypeOf(new MemDecode) + val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr)) + val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr)) + val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) + val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U + val xmax = (1 << mp.lenBits).U + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val ycnt = Reg(chiselTypeOf(dec.ysize)) + val ysize = dec.ysize + val tag = Reg(UInt(8.W)) + val set = Reg(UInt(8.W)) + + val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5) + val state = RegInit(sIdle) + + // control + switch (state) { + is (sIdle) { + when (io.start) { + state := sWriteCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } + is (sWriteCmd) { + when (io.vme_wr.cmd.ready) { + state := sWriteData + } + } + is (sWriteData) { + when (io.vme_wr.data.ready) { + when (xcnt === xlen) { + state := sWriteAck + } .elsewhen (tag === (numMemBlock - 1).U) { + state := sReadMem + } + } + } + is (sReadMem) { + state := sWriteData + } + is (sWriteAck) { + when (io.vme_wr.ack) { + when (xrem === 0.U) { + when (ycnt === ysize - 1.U) { + state := sIdle + } .otherwise { + state := sWriteCmd + when (xsize < xmax) { + xlen := xsize + xrem := 0.U + } .otherwise { + xlen := xmax - 1.U + xrem := xsize - xmax + } + } + } .elsewhen (xrem < xmax) { + state := sWriteCmd + xlen := xrem + xrem := 0.U + } .otherwise { + state := sWriteCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } + } + } + } + + // write-to-sram + val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) } + val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W))) + val no_mask = Wire(Vec(numMemBlock, Bool())) + + wdata_t := DontCare + no_mask.foreach { m => m := true.B } + + for (i <- 0 until tensorLength) { + val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t) + when (io.tensor.wr.valid) { + tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask) + } + } + + // read-from-sram + val stride = state === sWriteAck & + io.vme_wr.ack & + xcnt === xlen + 1.U & + xrem === 0.U & + ycnt =/= ysize - 1.U + + when (state === sIdle) { + ycnt := 0.U + } .elsewhen (stride) { + ycnt := ycnt + 1.U + } + + when (state === sWriteCmd || tag === (numMemBlock - 1).U) { + tag := 0.U + } .elsewhen (io.vme_wr.data.fire()) { + tag := tag + 1.U + } + + when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) { + set := 0.U + } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) { + set := set + 1.U + } + + val raddr_cur = Reg(UInt(tp.memAddrBits.W)) + val raddr_nxt = Reg(UInt(tp.memAddrBits.W)) + when (state === sIdle) { + raddr_cur := dec.sram_offset + raddr_nxt := dec.sram_offset + } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { + raddr_cur := raddr_cur + 1.U + } .elsewhen (stride) { + raddr_cur := raddr_nxt + dec.xsize + raddr_nxt := raddr_nxt + dec.xsize + } + + val tread = Seq.tabulate(tensorLength) { i => i.U -> + tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) } + val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread) + + // write-to-dram + when (state === sIdle) { + waddr_cur := io.baddr + dec.dram_offset + waddr_nxt := io.baddr + dec.dram_offset + } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { + waddr_cur := waddr_cur + xmax_bytes + } .elsewhen (stride) { + waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) + waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) + } + + io.vme_wr.cmd.valid := state === sWriteCmd + io.vme_wr.cmd.bits.addr := waddr_cur + io.vme_wr.cmd.bits.len := xlen + + io.vme_wr.data.valid := state === sWriteData + io.vme_wr.data.bits := mdata(tag) + + when (state === sWriteCmd) { + xcnt := 0.U + } .elsewhen (io.vme_wr.data.fire()) { + xcnt := xcnt + 1.U + } + + // disable external read-from-sram requests + io.tensor.tieoffRead() + + // done + io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U + + // debug + if (debug) { + when (io.vme_wr.cmd.fire()) { + printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem) + } + when (io.vme_wr.data.fire()) { + printf("[TensorStore] data:%x\n", io.vme_wr.data.bits) + } + when (io.vme_wr.ack) { + printf("[TensorStore] ack\n") + } + } +} diff --git a/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala new file mode 100644 index 000000000000..e41a2c5b18e9 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** TensorParams. + * + * This Bundle derives parameters for each tensorType, including inputs (inp), + * weights (wgt), biases (acc), and outputs (out). This is used to avoid + * doing the same boring calculations over and over again. + */ +class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle { + val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" + + require (tensorType == "inp" || tensorType == "wgt" + || tensorType == "acc" || tensorType == "out", errorMsg) + + val (tensorLength, tensorWidth, tensorElemBits) = + if (tensorType == "inp") + (p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits) + else if (tensorType == "wgt") + (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits) + else if (tensorType == "acc") + (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits) + else + (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits) + + val memBlockBits = p(ShellKey).memParams.dataBits + val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits + + val memDepth = + if (tensorType == "inp") + p(CoreKey).inpMemDepth + else if (tensorType == "wgt") + p(CoreKey).wgtMemDepth + else if (tensorType == "acc") + p(CoreKey).accMemDepth + else + p(CoreKey).outMemDepth + + val memAddrBits = log2Ceil(memDepth) +} + +/** TensorMaster. + * + * This interface issue read and write tensor-requests to scratchpads. For example, + * The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt), + * biases (acc), and outputs (out). + * + */ +class TensorMaster(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = ValidIO(UInt(memAddrBits.W)) + val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) + } + val wr = ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + }) + def tieoffRead() { + rd.idx.valid := false.B + rd.idx.bits := 0.U + } + def tieoffWrite() { + wr.valid := false.B + wr.bits.idx := 0.U + wr.bits.data.foreach { b => b.foreach { c => c := 0.U } } + } + override def cloneType = + new TensorMaster(tensorType).asInstanceOf[this.type] +} + +/** TensorClient. + * + * This interface receives read and write tensor-requests to scratchpads. For example, + * The TensorLoad unit uses this interface for receiving read and write requests from + * the TensorGemm unit. + */ +class TensorClient(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = Flipped(ValidIO(UInt(memAddrBits.W))) + val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) + } + val wr = Flipped(ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + })) + def tieoffRead() { + rd.data.valid := false.B + rd.data.bits.foreach { b => b.foreach { c => c := 0.U } } + } + override def cloneType = + new TensorClient(tensorType).asInstanceOf[this.type] +} + +/** TensorMasterData. + * + * This interface is only used for datapath only purposes and the direction convention + * is based on the TensorMaster interface, which means this is an input. This interface + * is used on datapath only module such MatrixVectorCore or AluVector. + */ +class TensorMasterData(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) + override def cloneType = + new TensorMasterData(tensorType).asInstanceOf[this.type] +} + +/** TensorClientData. + * + * This interface is only used for datapath only purposes and the direction convention + * is based on the TensorClient interface, which means this is an output. This interface + * is used on datapath only module such MatrixVectorCore or AluVector. + */ +class TensorClientData(tensorType: String = "none") + (implicit p: Parameters) extends TensorParams(tensorType) { + val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) + override def cloneType = + new TensorClientData(tensorType).asInstanceOf[this.type] +} + +/** TensorPadCtrl. Zero-padding controller for TensorLoad. */ +class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module { + val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" + require (padType == "YPad0" || padType == "YPad1" + || padType == "XPad0" || padType == "XPad1", errorMsg) + + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + }) + + val dec = io.inst.asTypeOf(new MemDecode) + + val xmax = Reg(chiselTypeOf(dec.xsize)) + val ymax = Reg(chiselTypeOf(dec.ypad_0)) + val xcnt = Reg(chiselTypeOf(dec.xsize)) + val ycnt = Reg(chiselTypeOf(dec.ypad_0)) + + val xval = + if (padType == "YPad0" || padType == "YPad1") + ((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U + else if (padType == "XPad0") + (dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U + else + (dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U + + val yval = + if (padType == "YPad0") + Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U) + else if (padType == "YPad1") + Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U) + else + 0.U + + val sIdle :: sActive :: Nil = Enum(2) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.start) { + state := sActive + } + } + is (sActive) { + when (ycnt === ymax && xcnt === xmax) { + state := sIdle + } + } + } + + when (state === sIdle) { + xmax := xval + ymax := yval + } + + when (state === sIdle || xcnt === xmax) { + xcnt := 0.U + } .elsewhen (state === sActive) { + xcnt := xcnt + 1.U + } + + when (state === sIdle || ymax === 0.U) { + ycnt := 0.U + } .elsewhen (state === sActive && xcnt === xmax) { + ycnt := ycnt + 1.U + } + + io.done := state === sActive & ycnt === ymax & xcnt === xmax +} + +/** TensorDataCtrl. Data controller for TensorLoad. */ +class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module { + val mp = p(ShellKey).memParams + val io = IO(new Bundle { + val start = Input(Bool()) + val done = Output(Bool()) + val inst = Input(UInt(INST_BITS.W)) + val baddr = Input(UInt(mp.addrBits.W)) + val xinit = Input(Bool()) + val xupdate = Input(Bool()) + val yupdate = Input(Bool()) + val stride = Output(Bool()) + val split = Output(Bool()) + val commit = Output(Bool()) + val addr = Output(UInt(mp.addrBits.W)) + val len = Output(UInt(mp.lenBits.W)) + }) + + val dec = io.inst.asTypeOf(new MemDecode) + + val caddr = Reg(UInt(mp.addrBits.W)) + val baddr = Reg(UInt(mp.addrBits.W)) + + val len = Reg(UInt(mp.lenBits.W)) + + val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xcnt = Reg(UInt(mp.lenBits.W)) + val xrem = Reg(chiselTypeOf(dec.xsize)) + val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U + val xmax = (1 << mp.lenBits).U + val ycnt = Reg(chiselTypeOf(dec.ysize)) + + val stride = xcnt === len & + xrem === 0.U & + ycnt =/= dec.ysize - 1.U + + val split = xcnt === len & xrem =/= 0.U + + when (io.start || (io.xupdate && stride)) { + when (xsize < xmax) { + len := xsize + xrem := 0.U + } .otherwise { + len := xmax - 1.U + xrem := xsize - xmax + } + } .elsewhen (io.xupdate && split) { + when (xrem < xmax) { + len := xrem + xrem := 0.U + } .otherwise { + len := xmax - 1.U + xrem := xrem - xmax + } + } + + when (io.xinit) { + xcnt := 0.U + } .elsewhen (io.xupdate) { + xcnt := xcnt + 1.U + } + + when (io.start) { + ycnt := 0.U + } .elsewhen (io.yupdate && stride) { + ycnt := ycnt + 1.U + } + + when (io.start) { + caddr := io.baddr + dec.dram_offset + baddr := io.baddr + dec.dram_offset + } .elsewhen (io.yupdate) { + when (split) { + caddr := caddr + xmax_bytes + } .elsewhen (stride) { + caddr := baddr + (dec.xstride << log2Ceil(strideFactor)) + baddr := baddr + (dec.xstride << log2Ceil(strideFactor)) + } + } + + io.stride := stride + io.split := split + io.commit := xcnt === len + io.addr := caddr + io.len := len + io.done := xcnt === len & + xrem === 0.U & + ycnt === dec.ysize - 1.U +} diff --git a/vta/hardware/chisel/src/main/scala/core/package.scala b/vta/hardware/chisel/src/main/scala/core/package.scala new file mode 100644 index 000000000000..673d390901de --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta + +/** This trick makes ISAConstants globally available */ +package object core extends vta.core.ISAConstants diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala index aab2d630c307..115bcbcb5a93 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala @@ -21,6 +21,9 @@ package vta.dpi import chisel3._ import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ /** Host DPI parameters */ trait VTAHostDPIParams { @@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource { }) setResource("/verilog/VTAHostDPI.v") } + +/** Host DPI to AXI Converter. + * + * Convert Host DPI to AXI for VTAShell + */ + +class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val dpi = new VTAHostDPIClient + val axi = new AXILiteMaster(p(ShellKey).hostParams) + }) + val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) + val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value))) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.dpi.req.valid) { + when (io.dpi.req.opcode) { + state := sWriteAddress + } .otherwise { + state := sReadAddress + } + } + } + is (sReadAddress) { + when (io.axi.ar.ready) { + state := sReadData + } + } + is (sReadData) { + when (io.axi.r.valid) { + state := sIdle + } + } + is (sWriteAddress) { + when (io.axi.aw.ready) { + state := sWriteData + } + } + is (sWriteData) { + when (io.axi.w.ready) { + state := sWriteResponse + } + } + is (sWriteResponse) { + when (io.axi.b.valid) { + state := sIdle + } + } + } + + when (state === sIdle && io.dpi.req.valid) { + addr := io.dpi.req.addr + data := io.dpi.req.value + } + + io.axi.aw.valid := state === sWriteAddress + io.axi.aw.bits.addr := addr + io.axi.w.valid := state === sWriteData + io.axi.w.bits.data := data + io.axi.w.bits.strb := "h_f".U + io.axi.b.ready := state === sWriteResponse + + io.axi.ar.valid := state === sReadAddress + io.axi.ar.bits.addr := addr + io.axi.r.ready := state === sReadData + + io.dpi.req.deq := (state === sReadAddress & io.axi.ar.ready) | (state === sWriteAddress & io.axi.aw.ready) + io.dpi.resp.valid := io.axi.r.valid + io.dpi.resp.bits := io.axi.r.bits.data + + if (debug) { + when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) } + when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) } + when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) } + when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) } + } +} diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala index 090f0459570a..5e2fa741d72a 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala @@ -21,6 +21,9 @@ package vta.dpi import chisel3._ import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ /** Memory DPI parameters */ trait VTAMemDPIParams { @@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource { }) setResource("/verilog/VTAMemDPI.v") } + +class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val dpi = new VTAMemDPIMaster + val axi = new AXIClient(p(ShellKey).memParams) + }) + val opcode = RegInit(false.B) + val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len))) + val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val state = RegInit(sIdle) + + switch (state) { + is (sIdle) { + when (io.axi.ar.valid) { + state := sReadAddress + } .elsewhen (io.axi.aw.valid) { + state := sWriteAddress + } + } + is (sReadAddress) { + when (io.axi.ar.valid) { + state := sReadData + } + } + is (sReadData) { + when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) { + state := sIdle + } + } + is (sWriteAddress) { + when (io.axi.aw.valid) { + state := sWriteData + } + } + is (sWriteData) { + when (io.axi.w.valid && io.axi.w.bits.last) { + state := sWriteResponse + } + } + is (sWriteResponse) { + when (io.axi.b.ready) { + state := sIdle + } + } + } + + when (state === sIdle) { + when (io.axi.ar.valid) { + opcode := false.B + len := io.axi.ar.bits.len + addr := io.axi.ar.bits.addr + } .elsewhen (io.axi.aw.valid) { + opcode := true.B + len := io.axi.aw.bits.len + addr := io.axi.aw.bits.addr + } + } .elsewhen (state === sReadData) { + when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) { + len := len - 1.U + } + } + + io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid) + io.dpi.req.opcode := opcode + io.dpi.req.len := len + io.dpi.req.addr := addr + + io.axi.ar.ready := state === sReadAddress + io.axi.aw.ready := state === sWriteAddress + + io.axi.r.valid := state === sReadData & io.dpi.rd.valid + io.axi.r.bits.data := io.dpi.rd.bits + io.axi.r.bits.last := len === 0.U + io.axi.r.bits.resp := 0.U + io.axi.r.bits.user := 0.U + io.axi.r.bits.id := 0.U + io.dpi.rd.ready := state === sReadData & io.axi.r.ready + + io.dpi.wr.valid := state === sWriteData & io.axi.w.valid + io.dpi.wr.bits := io.axi.w.bits.data + io.axi.w.ready := state === sWriteData + + io.axi.b.valid := state === sWriteResponse + io.axi.b.bits.resp := 0.U + io.axi.b.bits.user := 0.U + io.axi.b.bits.id := 0.U + + if (debug) { + when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) } + when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) } + when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) } + when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) } + } +} diff --git a/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala new file mode 100644 index 000000000000..a853e85e2bd8 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.interface.axi + +import chisel3._ +import chisel3.util._ +import vta.util.genericbundle._ + +case class AXIParams( + addrBits: Int = 32, + dataBits: Int = 64 +) +{ + require (addrBits > 0) + require (dataBits >= 8 && dataBits % 2 == 0) + + val idBits = 1 + val userBits = 1 + val strbBits = dataBits/8 + val lenBits = 8 + val sizeBits = 3 + val burstBits = 2 + val lockBits = 2 + val cacheBits = 4 + val protBits = 3 + val qosBits = 4 + val regionBits = 4 + val respBits = 2 + val sizeConst = log2Ceil(dataBits/8) + val idConst = 0 + val userConst = 0 + val burstConst = 1 + val lockConst = 0 + val cacheConst = 3 + val protConst = 0 + val qosConst = 0 + val regionConst = 0 +} + +abstract class AXIBase(params: AXIParams) + extends GenericParameterizedBundle(params) + +// AXILite + +class AXILiteAddress(params: AXIParams) extends AXIBase(params) { + val addr = UInt(params.addrBits.W) +} + +class AXILiteWriteData(params: AXIParams) extends AXIBase(params) { + val data = UInt(params.dataBits.W) + val strb = UInt(params.strbBits.W) +} + +class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) { + val resp = UInt(params.respBits.W) +} + +class AXILiteReadData(params: AXIParams) extends AXIBase(params) { + val data = UInt(params.dataBits.W) + val resp = UInt(params.respBits.W) +} + +class AXILiteMaster(params: AXIParams) extends AXIBase(params) { + val aw = Decoupled(new AXILiteAddress(params)) + val w = Decoupled(new AXILiteWriteData(params)) + val b = Flipped(Decoupled(new AXILiteWriteResponse(params))) + val ar = Decoupled(new AXILiteAddress(params)) + val r = Flipped(Decoupled(new AXILiteReadData(params))) + + def tieoff() { + aw.valid := false.B + aw.bits.addr := 0.U + w.valid := false.B + w.bits.data := 0.U + w.bits.strb := 0.U + b.ready := false.B + ar.valid := false.B + ar.bits.addr := 0.U + r.ready := false.B + } +} + +class AXILiteClient(params: AXIParams) extends AXIBase(params) { + val aw = Flipped(Decoupled(new AXILiteAddress(params))) + val w = Flipped(Decoupled(new AXILiteWriteData(params))) + val b = Decoupled(new AXILiteWriteResponse(params)) + val ar = Flipped(Decoupled(new AXILiteAddress(params))) + val r = Decoupled(new AXILiteReadData(params)) + + def tieoff() { + aw.ready := false.B + w.ready := false.B + b.valid := false.B + b.bits.resp := 0.U + ar.ready := false.B + r.valid := false.B + r.bits.resp := 0.U + r.bits.data := 0.U + } +} + +// AXI extends AXILite + +class AXIAddress(params: AXIParams) extends AXILiteAddress(params) { + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) + val len = UInt(params.lenBits.W) + val size = UInt(params.sizeBits.W) + val burst = UInt(params.burstBits.W) + val lock = UInt(params.lockBits.W) + val cache = UInt(params.cacheBits.W) + val prot = UInt(params.protBits.W) + val qos = UInt(params.qosBits.W) + val region = UInt(params.regionBits.W) +} + +class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) { + val last = Bool() + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) { + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIReadData(params: AXIParams) extends AXILiteReadData(params) { + val last = Bool() + val id = UInt(params.idBits.W) + val user = UInt(params.userBits.W) +} + +class AXIMaster(params: AXIParams) extends AXIBase(params) { + val aw = Decoupled(new AXIAddress(params)) + val w = Decoupled(new AXIWriteData(params)) + val b = Flipped(Decoupled(new AXIWriteResponse(params))) + val ar = Decoupled(new AXIAddress(params)) + val r = Flipped(Decoupled(new AXIReadData(params))) + + def tieoff() { + aw.valid := false.B + aw.bits.addr := 0.U + aw.bits.id := 0.U + aw.bits.user := 0.U + aw.bits.len := 0.U + aw.bits.size := 0.U + aw.bits.burst := 0.U + aw.bits.lock := 0.U + aw.bits.cache := 0.U + aw.bits.prot := 0.U + aw.bits.qos := 0.U + aw.bits.region := 0.U + w.valid := false.B + w.bits.data := 0.U + w.bits.strb := 0.U + w.bits.last := false.B + w.bits.id := 0.U + w.bits.user := 0.U + b.ready := false.B + ar.valid := false.B + ar.bits.addr := 0.U + ar.bits.id := 0.U + ar.bits.user := 0.U + ar.bits.len := 0.U + ar.bits.size := 0.U + ar.bits.burst := 0.U + ar.bits.lock := 0.U + ar.bits.cache := 0.U + ar.bits.prot := 0.U + ar.bits.qos := 0.U + ar.bits.region := 0.U + r.ready := false.B + } + + def setConst() { + aw.bits.user := params.userConst.U + aw.bits.burst := params.burstConst.U + aw.bits.lock := params.lockConst.U + aw.bits.cache := params.cacheConst.U + aw.bits.prot := params.protConst.U + aw.bits.qos := params.qosConst.U + aw.bits.region := params.regionConst.U + aw.bits.size := params.sizeConst.U + aw.bits.id := params.idConst.U + w.bits.id := params.idConst.U + w.bits.user := params.userConst.U + w.bits.strb := Fill(params.strbBits, true.B) + ar.bits.user := params.userConst.U + ar.bits.burst := params.burstConst.U + ar.bits.lock := params.lockConst.U + ar.bits.cache := params.cacheConst.U + ar.bits.prot := params.protConst.U + ar.bits.qos := params.qosConst.U + ar.bits.region := params.regionConst.U + ar.bits.size := params.sizeConst.U + ar.bits.id := params.idConst.U + } +} + +class AXIClient(params: AXIParams) extends AXIBase(params) { + val aw = Flipped(Decoupled(new AXIAddress(params))) + val w = Flipped(Decoupled(new AXIWriteData(params))) + val b = Decoupled(new AXIWriteResponse(params)) + val ar = Flipped(Decoupled(new AXIAddress(params))) + val r = Decoupled(new AXIReadData(params)) + + def tieoff() { + aw.ready := false.B + w.ready := false.B + b.valid := false.B + b.bits.resp := 0.U + b.bits.user := 0.U + b.bits.id := 0.U + ar.ready := false.B + r.valid := false.B + r.bits.resp := 0.U + r.bits.data := 0.U + r.bits.user := 0.U + r.bits.last := false.B + r.bits.id := 0.U + } +} + +// XilinxAXILiteClient and XilinxAXIMaster bundles are needed +// for wrapper purposes, because the package RTL tool in Xilinx Vivado +// only allows certain name formats + +class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) { + val AWVALID = Input(Bool()) + val AWREADY = Output(Bool()) + val AWADDR = Input(UInt(params.addrBits.W)) + val WVALID = Input(Bool()) + val WREADY = Output(Bool()) + val WDATA = Input(UInt(params.dataBits.W)) + val WSTRB = Input(UInt(params.strbBits.W)) + val BVALID = Output(Bool()) + val BREADY = Input(Bool()) + val BRESP = Output(UInt(params.respBits.W)) + val ARVALID = Input(Bool()) + val ARREADY = Output(Bool()) + val ARADDR = Input(UInt(params.addrBits.W)) + val RVALID = Output(Bool()) + val RREADY = Input(Bool()) + val RDATA = Output(UInt(params.dataBits.W)) + val RRESP = Output(UInt(params.respBits.W)) +} + +class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) { + val AWVALID = Output(Bool()) + val AWREADY = Input(Bool()) + val AWADDR = Output(UInt(params.addrBits.W)) + val AWID = Output(UInt(params.idBits.W)) + val AWUSER = Output(UInt(params.userBits.W)) + val AWLEN = Output(UInt(params.lenBits.W)) + val AWSIZE = Output(UInt(params.sizeBits.W)) + val AWBURST = Output(UInt(params.burstBits.W)) + val AWLOCK = Output(UInt(params.lockBits.W)) + val AWCACHE = Output(UInt(params.cacheBits.W)) + val AWPROT = Output(UInt(params.protBits.W)) + val AWQOS = Output(UInt(params.qosBits.W)) + val AWREGION = Output(UInt(params.regionBits.W)) + val WVALID = Output(Bool()) + val WREADY = Input(Bool()) + val WDATA = Output(UInt(params.dataBits.W)) + val WSTRB = Output(UInt(params.strbBits.W)) + val WLAST = Output(Bool()) + val WID = Output(UInt(params.idBits.W)) + val WUSER = Output(UInt(params.userBits.W)) + val BVALID = Input(Bool()) + val BREADY = Output(Bool()) + val BRESP = Input(UInt(params.respBits.W)) + val BID = Input(UInt(params.idBits.W)) + val BUSER = Input(UInt(params.userBits.W)) + val ARVALID = Output(Bool()) + val ARREADY = Input(Bool()) + val ARADDR = Output(UInt(params.addrBits.W)) + val ARID = Output(UInt(params.idBits.W)) + val ARUSER = Output(UInt(params.userBits.W)) + val ARLEN = Output(UInt(params.lenBits.W)) + val ARSIZE = Output(UInt(params.sizeBits.W)) + val ARBURST = Output(UInt(params.burstBits.W)) + val ARLOCK = Output(UInt(params.lockBits.W)) + val ARCACHE = Output(UInt(params.cacheBits.W)) + val ARPROT = Output(UInt(params.protBits.W)) + val ARQOS = Output(UInt(params.qosBits.W)) + val ARREGION = Output(UInt(params.regionBits.W)) + val RVALID = Input(Bool()) + val RREADY = Output(Bool()) + val RDATA = Input(UInt(params.dataBits.W)) + val RRESP = Input(UInt(params.respBits.W)) + val RLAST = Input(Bool()) + val RID = Input(UInt(params.idBits.W)) + val RUSER = Input(UInt(params.userBits.W)) +} diff --git a/vta/hardware/chisel/src/main/scala/shell/Configs.scala b/vta/hardware/chisel/src/main/scala/shell/Configs.scala new file mode 100644 index 000000000000..1d1d5223b73c --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/Configs.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.interface.axi._ + +/** PynqConfig. Shell configuration for Pynq */ +class PynqConfig extends Config((site, here, up) => { + case ShellKey => ShellParams( + hostParams = AXIParams( + addrBits = 16, + dataBits = 32), + memParams = AXIParams( + addrBits = 32, + dataBits = 64), + vcrParams = VCRParams(), + vmeParams = VMEParams()) +}) + +/** F1Config. Shell configuration for F1 */ +class F1Config extends Config((site, here, up) => { + case ShellKey => ShellParams( + hostParams = AXIParams( + addrBits = 16, + dataBits = 32), + memParams = AXIParams( + addrBits = 64, + dataBits = 64), + vcrParams = VCRParams(), + vmeParams = VMEParams()) +}) diff --git a/vta/hardware/chisel/src/main/scala/shell/SimShell.scala b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala new file mode 100644 index 000000000000..3ad4b6548ce3 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import vta.util.config._ +import vta.interface.axi._ +import vta.shell._ +import vta.dpi._ + +/** VTAHost. + * + * This module translate the DPI protocol into AXI. This is a simulation only + * module and used to test host-to-VTA communication. This module should be updated + * for testing hosts using a different bus protocol, other than AXI. + */ +class VTAHost(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val axi = new AXILiteMaster(p(ShellKey).hostParams) + }) + val host_dpi = Module(new VTAHostDPI) + val host_axi = Module(new VTAHostDPIToAXI) + host_dpi.io.reset := reset + host_dpi.io.clock := clock + host_axi.io.dpi <> host_dpi.io.dpi + io.axi <> host_axi.io.axi +} + +/** VTAMem. + * + * This module translate the DPI protocol into AXI. This is a simulation only + * module and used to test VTA-to-memory communication. This module should be updated + * for testing memories using a different bus protocol, other than AXI. + */ +class VTAMem(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val axi = new AXIClient(p(ShellKey).memParams) + }) + val mem_dpi = Module(new VTAMemDPI) + val mem_axi = Module(new VTAMemDPIToAXI) + mem_dpi.io.reset := reset + mem_dpi.io.clock := clock + mem_dpi.io.dpi <> mem_axi.io.dpi + mem_axi.io.axi <> io.axi +} + +/** SimShell. + * + * The simulation shell instantiate a host and memory simulation modules and it is + * intended to be connected to the VTAShell. + */ +class SimShell(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val mem = new AXIClient(p(ShellKey).memParams) + val host = new AXILiteMaster(p(ShellKey).hostParams) + }) + val host = Module(new VTAHost) + val mem = Module(new VTAMem) + io.mem <> mem.io.axi + io.host <> host.io.axi +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VCR.scala b/vta/hardware/chisel/src/main/scala/shell/VCR.scala new file mode 100644 index 000000000000..463f55bc8bbd --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VCR.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.util.genericbundle._ +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.LinkedHashMap +import vta.interface.axi._ + +/** VCR parameters. + * + * These parameters are used on VCR interfaces and modules. + */ +case class VCRParams() +{ + val nValsReg: Int = 1 + val nPtrsReg: Int = 6 + val regBits: Int = 32 + val nCtrlReg: Int = 4 + val ctrlBaseAddr: Int = 0 + + require (nValsReg > 0) + require (nPtrsReg > 0) +} + +/** VCRBase. Parametrize base class. */ +abstract class VCRBase(implicit p: Parameters) + extends GenericParameterizedBundle(p) + +/** VCRMaster. + * + * This is the master interface used by VCR in the VTAShell to control + * the Core unit. + */ +class VCRMaster(implicit p: Parameters) extends VCRBase { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val launch = Output(Bool()) + val finish = Input(Bool()) + val irq = Output(Bool()) + val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) + val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W))) +} + +/** VCRClient. + * + * This is the client interface used by the Core module to communicate + * to the VCR in the VTAShell. + */ +class VCRClient(implicit p: Parameters) extends VCRBase { + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val launch = Input(Bool()) + val finish = Output(Bool()) + val irq = Input(Bool()) + val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) + val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W))) +} + +/** VTA Control Registers (VCR). + * + * This unit provides control registers (32 and 64 bits) to be used by a control' + * unit, typically a host processor. These registers are read-only by the core + * at the moment but this will likely change once we add support to general purpose + * registers that could be used as event counters by the Core unit. + */ +class VCR(implicit p: Parameters) extends Module { + val io = IO(new Bundle{ + val host = new AXILiteClient(p(ShellKey).hostParams) + val vcr = new VCRMaster + }) + + val vp = p(ShellKey).vcrParams + val mp = p(ShellKey).memParams + val hp = p(ShellKey).hostParams + + // Write control (AW, W, B) + val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address + val wdata = io.host.w.bits.data + val wstrb = io.host.w.bits.strb + val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0))) + val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3) + val wstate = RegInit(sWriteAddress) + switch (wstate) { + is (sWriteAddress) { + when (io.host.aw.valid) { + wstate := sWriteData + } + } + is (sWriteData) { + when (io.host.w.valid) { + wstate := sWriteResponse + } + } + is (sWriteResponse) { + when (io.host.b.ready) { + wstate := sWriteAddress + } + } + } + + when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr } + + io.host.aw.ready := wstate === sWriteAddress + io.host.w.ready := wstate === sWriteData + io.host.b.valid := wstate === sWriteResponse + io.host.b.bits.resp := "h_0".U + + // read control (AR, R) + val sReadAddress :: sReadData :: Nil = Enum(2) + val rstate = RegInit(sReadAddress) + + switch (rstate) { + is (sReadAddress) { + when (io.host.ar.valid) { + rstate := sReadData + } + } + is (sReadData) { + when (io.host.r.ready) { + rstate := sReadAddress + } + } + } + + io.host.ar.ready := rstate === sReadAddress + io.host.r.valid := rstate === sReadData + + val nPtrsReg = vp.nPtrsReg + val nValsReg = vp.nValsReg + val regBits = vp.regBits + val ptrsBits = mp.addrBits + val nCtrlReg = vp.nCtrlReg + val rStride = regBits/8 + val pStride = ptrsBits/8 + val ctrlBaseAddr = vp.ctrlBaseAddr + val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride + val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride + + val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr) + val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr) + + val ptrsAddr = new ListBuffer[Int]() + for (i <- 0 until nPtrsReg) { + ptrsAddr += i*pStride + ptrsBaseAddr + if (ptrsBits == 64) { + ptrsAddr += i*pStride + rStride + ptrsBaseAddr + } + } + + // AP register + val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B))) + + // ap start + when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) { + c0(0) := true.B + } .elsewhen (io.vcr.finish) { + c0(0) := false.B + } + + // ap done = finish + when (io.vcr.finish) { + c0(1) := true.B + } .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) { + c0(1) := false.B + } + + val c1 = 0.U + val c2 = 0.U + val c3 = 0.U + + val ctrlRegList = List(c0, c1, c2, c3) + + io.vcr.launch := c0(0) + + // interrupts not supported atm + io.vcr.irq := false.B + + // Write pointer and value registers + val pvAddr = valsAddr ++ ptrsAddr + val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg + val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W)))) + val pvRegList = new ListBuffer[UInt]() + + for (i <- 0 until pvNumReg) { + when (io.host.w.fire() && (waddr === pvAddr(i).U)) { + pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask) + } + pvRegList += pvReg(i) + } + + for (i <- 0 until nValsReg) { + io.vcr.vals(i) := pvReg(i) + } + + for (i <- 0 until nPtrsReg) { + if (ptrsBits == 64) { + io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2)) + } else { + io.vcr.ptrs(i) := pvReg(nValsReg + i) + } + } + + // Read pointer and value registers + val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr + val mapRegList = ctrlRegList ++ pvRegList + + val rdata = RegInit(0.U(regBits.W)) + val rmap = LinkedHashMap[Int,UInt]() + + val totalReg = mapRegList.length + for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt } + + val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) } + + when (io.host.ar.fire()) { + rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v) + } + + io.host.r.bits.resp := 0.U + io.host.r.bits.data := rdata +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VME.scala b/vta/hardware/chisel/src/main/scala/shell/VME.scala new file mode 100644 index 000000000000..862e9810c510 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VME.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.util.genericbundle._ +import vta.interface.axi._ + +/** VME parameters. + * + * These parameters are used on VME interfaces and modules. + */ +case class VMEParams() { + val nReadClients: Int = 5 + val nWriteClients: Int = 1 + require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") + require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") +} + +/** VMEBase. Parametrize base class. */ +abstract class VMEBase(implicit p: Parameters) + extends GenericParameterizedBundle(p) + +/** VMECmd. + * + * This interface is used for creating write and read requests to memory. + */ +class VMECmd(implicit p: Parameters) extends VMEBase { + val addrBits = p(ShellKey).memParams.addrBits + val lenBits = p(ShellKey).memParams.lenBits + val addr = UInt(addrBits.W) + val len = UInt(lenBits.W) +} + +/** VMEReadMaster. + * + * This interface is used by modules inside the core to generate read requests + * and receive responses from VME. + */ +class VMEReadMaster(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Decoupled(new VMECmd) + val data = Flipped(Decoupled(UInt(dataBits.W))) + override def cloneType = + new VMEReadMaster().asInstanceOf[this.type] +} + +/** VMEReadClient. + * + * This interface is used by the VME to receive read requests and generate + * responses to modules inside the core. + */ +class VMEReadClient(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Flipped(Decoupled(new VMECmd)) + val data = Decoupled(UInt(dataBits.W)) + override def cloneType = + new VMEReadClient().asInstanceOf[this.type] +} + +/** VMEWriteMaster. + * + * This interface is used by modules inside the core to generate write requests + * to the VME. + */ +class VMEWriteMaster(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Decoupled(new VMECmd) + val data = Decoupled(UInt(dataBits.W)) + val ack = Input(Bool()) + override def cloneType = + new VMEWriteMaster().asInstanceOf[this.type] +} + +/** VMEWriteClient. + * + * This interface is used by the VME to handle write requests from modules inside + * the core. + */ +class VMEWriteClient(implicit p: Parameters) extends Bundle { + val dataBits = p(ShellKey).memParams.dataBits + val cmd = Flipped(Decoupled(new VMECmd)) + val data = Flipped(Decoupled(UInt(dataBits.W))) + val ack = Output(Bool()) + override def cloneType = + new VMEWriteClient().asInstanceOf[this.type] +} + +/** VMEMaster. + * + * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster + * interfaces. + */ +class VMEMaster(implicit p: Parameters) extends Bundle { + val nRd = p(ShellKey).vmeParams.nReadClients + val nWr = p(ShellKey).vmeParams.nWriteClients + val rd = Vec(nRd, new VMEReadMaster) + val wr = Vec(nWr, new VMEWriteMaster) +} + +/** VMEClient. + * + * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient + * interfaces. + */ +class VMEClient(implicit p: Parameters) extends Bundle { + val nRd = p(ShellKey).vmeParams.nReadClients + val nWr = p(ShellKey).vmeParams.nWriteClients + val rd = Vec(nRd, new VMEReadClient) + val wr = Vec(nWr, new VMEWriteClient) +} + +/** VTA Memory Engine (VME). + * + * This unit multiplexes the memory controller interface for the Core. Currently, + * it supports single-writer and multiple-reader mode and it is also based on AXI. + */ +class VME(implicit p: Parameters) extends Module { + val io = IO(new Bundle { + val mem = new AXIMaster(p(ShellKey).memParams) + val vme = new VMEClient + }) + + val nReadClients = p(ShellKey).vmeParams.nReadClients + val rd_arb = Module(new Arbiter(new VMECmd, nReadClients)) + val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire()) + + for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd } + + val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3) + val rstate = RegInit(sReadIdle) + + switch (rstate) { + is (sReadIdle) { + when (rd_arb.io.out.valid) { + rstate := sReadAddr + } + } + is (sReadAddr) { + when (io.mem.ar.ready) { + rstate := sReadData + } + } + is (sReadData) { + when (io.mem.r.fire() && io.mem.r.bits.last) { + rstate := sReadIdle + } + } + } + + val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4) + val wstate = RegInit(sWriteIdle) + val addrBits = p(ShellKey).memParams.addrBits + val lenBits = p(ShellKey).memParams.lenBits + val wr_cnt = RegInit(0.U(lenBits.W)) + + when (wstate === sWriteIdle) { + wr_cnt := 0.U + } .elsewhen (io.mem.w.fire()) { + wr_cnt := wr_cnt + 1.U + } + + switch (wstate) { + is (sWriteIdle) { + when (io.vme.wr(0).cmd.valid) { + wstate := sWriteAddr + } + } + is (sWriteAddr) { + when (io.mem.aw.ready) { + wstate := sWriteData + } + } + is (sWriteData) { + when (io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) { + wstate := sWriteResp + } + } + is (sWriteResp) { + when (io.mem.b.valid) { + wstate := sWriteIdle + } + } + } + + // registers storing read/write cmds + + val rd_len = RegInit(0.U(lenBits.W)) + val wr_len = RegInit(0.U(lenBits.W)) + val rd_addr = RegInit(0.U(addrBits.W)) + val wr_addr = RegInit(0.U(addrBits.W)) + + when (rd_arb.io.out.fire()) { + rd_len := rd_arb.io.out.bits.len + rd_addr := rd_arb.io.out.bits.addr + } + + when (io.vme.wr(0).cmd.fire()) { + wr_len := io.vme.wr(0).cmd.bits.len + wr_addr := io.vme.wr(0).cmd.bits.addr + } + + // rd arb + rd_arb.io.out.ready := rstate === sReadIdle + + // vme + for (i <- 0 until nReadClients) { + io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid + io.vme.rd(i).data.bits := io.mem.r.bits.data + } + + io.vme.wr(0).cmd.ready := wstate === sWriteIdle + io.vme.wr(0).ack := io.mem.b.fire() + io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready + + // mem + io.mem.aw.valid := wstate === sWriteAddr + io.mem.aw.bits.addr := wr_addr + io.mem.aw.bits.len := wr_len + + io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid + io.mem.w.bits.data := io.vme.wr(0).data.bits + io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len + + io.mem.b.ready := wstate === sWriteResp + + io.mem.ar.valid := rstate === sReadAddr + io.mem.ar.bits.addr := rd_addr + io.mem.ar.bits.len := rd_len + + io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready + + // AXI constants - statically defined + io.mem.setConst() +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala new file mode 100644 index 000000000000..c8093118308f --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import vta.util.config._ +import vta.interface.axi._ +import vta.core._ + +/** Shell parameters. */ +case class ShellParams( + hostParams: AXIParams, + memParams: AXIParams, + vcrParams: VCRParams, + vmeParams: VMEParams +) + +case object ShellKey extends Field[ShellParams] + +/** VTAShell. + * + * The VTAShell is based on a VME, VCR and core. This creates a complete VTA + * system that can be used for simulation or real hardware. + */ +class VTAShell(implicit p: Parameters) extends Module { + val io = IO(new Bundle{ + val host = new AXILiteClient(p(ShellKey).hostParams) + val mem = new AXIMaster(p(ShellKey).memParams) + }) + + val vcr = Module(new VCR) + val vme = Module(new VME) + val core = Module(new Core) + + core.io.vcr <> vcr.io.vcr + vme.io.vme <> core.io.vme + + vcr.io.host <> io.host + io.mem <> vme.io.mem +} diff --git a/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala new file mode 100644 index 000000000000..db721373b7e3 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.shell + +import chisel3._ +import chisel3.experimental.{RawModule, withClockAndReset} +import vta.util.config._ +import vta.interface.axi._ + +/** XilinxShell. + * + * This is a wrapper shell mostly used to match Xilinx convention naming, + * therefore we can pack VTA as an IP for IPI based flows. + */ +class XilinxShell(implicit p: Parameters) extends RawModule { + + val hp = p(ShellKey).hostParams + val mp = p(ShellKey).memParams + + val ap_clk = IO(Input(Clock())) + val ap_rst_n = IO(Input(Bool())) + val m_axi_gmem = IO(new XilinxAXIMaster(mp)) + val s_axi_control = IO(new XilinxAXILiteClient(hp)) + + val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) } + + // memory + m_axi_gmem.AWVALID := shell.io.mem.aw.valid + shell.io.mem.aw.ready := m_axi_gmem.AWREADY + m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr + m_axi_gmem.AWID := shell.io.mem.aw.bits.id + m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user + m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len + m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size + m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst + m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock + m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache + m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot + m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos + m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region + + m_axi_gmem.WVALID := shell.io.mem.w.valid + shell.io.mem.w.ready := m_axi_gmem.WREADY + m_axi_gmem.WDATA := shell.io.mem.w.bits.data + m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb + m_axi_gmem.WLAST := shell.io.mem.w.bits.last + m_axi_gmem.WID := shell.io.mem.w.bits.id + m_axi_gmem.WUSER := shell.io.mem.w.bits.user + + shell.io.mem.b.valid := m_axi_gmem.BVALID + m_axi_gmem.BREADY := shell.io.mem.b.valid + shell.io.mem.b.bits.resp := m_axi_gmem.BRESP + shell.io.mem.b.bits.id := m_axi_gmem.BID + shell.io.mem.b.bits.user := m_axi_gmem.BUSER + + m_axi_gmem.ARVALID := shell.io.mem.ar.valid + shell.io.mem.ar.ready := m_axi_gmem.ARREADY + m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr + m_axi_gmem.ARID := shell.io.mem.ar.bits.id + m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user + m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len + m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size + m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst + m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock + m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache + m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot + m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos + m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region + + shell.io.mem.r.valid := m_axi_gmem.RVALID + m_axi_gmem.RREADY := shell.io.mem.r.ready + shell.io.mem.r.bits.data := m_axi_gmem.RDATA + shell.io.mem.r.bits.resp := m_axi_gmem.RRESP + shell.io.mem.r.bits.last := m_axi_gmem.RLAST + shell.io.mem.r.bits.id := m_axi_gmem.RID + shell.io.mem.r.bits.user := m_axi_gmem.RUSER + + // host + shell.io.host.aw.valid := s_axi_control.AWVALID + s_axi_control.AWREADY := shell.io.host.aw.ready + shell.io.host.aw.bits.addr := s_axi_control.AWADDR + + shell.io.host.w.valid := s_axi_control.WVALID + s_axi_control.WREADY := shell.io.host.w.ready + shell.io.host.w.bits.data := s_axi_control.WDATA + shell.io.host.w.bits.strb := s_axi_control.WSTRB + + s_axi_control.BVALID := shell.io.host.b.valid + shell.io.host.b.ready := s_axi_control.BREADY + s_axi_control.BRESP := shell.io.host.b.bits.resp + + shell.io.host.ar.valid := s_axi_control.ARVALID + s_axi_control.ARREADY := shell.io.host.ar.ready + shell.io.host.ar.bits.addr := s_axi_control.ARADDR + + s_axi_control.RVALID := shell.io.host.r.valid + shell.io.host.r.ready := s_axi_control.RREADY + s_axi_control.RDATA := shell.io.host.r.bits.data + s_axi_control.RRESP := shell.io.host.r.bits.resp +} diff --git a/vta/hardware/chisel/src/main/scala/test/Test.scala b/vta/hardware/chisel/src/main/scala/test/Test.scala new file mode 100644 index 000000000000..db060739147d --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/test/Test.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.test + +import chisel3._ +import vta.util.config._ +import vta.shell._ + +/** Test. This generates a testbench file for simulation */ +class Test(implicit p: Parameters) extends Module { + val io = IO(new Bundle {}) + val sim_shell = Module(new SimShell) + val vta_shell = Module(new VTAShell) + vta_shell.io.host <> sim_shell.io.host + sim_shell.io.mem <> vta_shell.io.mem +} diff --git a/vta/hardware/chisel/src/main/scala/util/Config.scala b/vta/hardware/chisel/src/main/scala/util/Config.scala new file mode 100644 index 000000000000..6699507c9f13 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/util/Config.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.util.config + +// taken from https://github.com/vta.roject/rocket-chip + +abstract class Field[T] private (val default: Option[T]) +{ + def this() = this(None) + def this(default: T) = this(Some(default)) +} + +abstract class View { + final def apply[T](pname: Field[T]): T = apply(pname, this) + final def apply[T](pname: Field[T], site: View): T = { + val out = find(pname, site) + require (out.isDefined, s"Key ${pname} is not defined in Parameters") + out.get + } + + final def lift[T](pname: Field[T]): Option[T] = lift(pname, this) + final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T]) + + protected[config] def find[T](pname: Field[T], site: View): Option[T] +} + +abstract class Parameters extends View { + final def ++ (x: Parameters): Parameters = + new ChainParameters(this, x) + + final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = + Parameters(f) ++ this + + final def alterPartial(f: PartialFunction[Any,Any]): Parameters = + Parameters((_,_,_) => f) ++ this + + final def alterMap(m: Map[Any,Any]): Parameters = + new MapParameters(m) ++ this + + protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T] + protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname) +} + +object Parameters { + def empty: Parameters = new EmptyParameters + def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f) +} + +class Config(p: Parameters) extends Parameters { + def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f)) + + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname) + override def toString = this.getClass.getSimpleName + def toInstance = this +} + +// Internal implementation: + +private class TerminalView extends View { + def find[T](pname: Field[T], site: View): Option[T] = pname.default +} + +private class ChainView(head: Parameters, tail: View) extends View { + def find[T](pname: Field[T], site: View) = head.chain(site, tail, pname) +} + +private class ChainParameters(x: Parameters, y: Parameters) extends Parameters { + def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname) +} + +private class EmptyParameters extends Parameters { + def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site) +} + +private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters { + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { + val g = f(site, this, tail) + if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site) + } +} + +private class MapParameters(map: Map[Any, Any]) extends Parameters { + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { + val g = map.get(pname) + if (g.isDefined) Some(g.get.asInstanceOf[T]) else tail.find(pname, site) + } +} diff --git a/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala new file mode 100644 index 000000000000..db19635c9345 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.util.genericbundle + +// taken from https://github.com/vta.roject/rocket-chip + +import chisel3._ + +abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle +{ + override def cloneType = { + try { + this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type] + } catch { + case e: java.lang.IllegalArgumentException => + throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " + + this.getClass + ", probably because " + this.getClass + + "() takes more than one argument. Consider overriding " + + "cloneType() on " + this.getClass, e) + } + } +} + diff --git a/vta/hardware/chisel/src/main/scala/vta/Configs.scala b/vta/hardware/chisel/src/main/scala/vta/Configs.scala new file mode 100644 index 000000000000..d5aa12798fe7 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/vta/Configs.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta + +import chisel3._ +import vta.util.config._ +import vta.shell._ +import vta.core._ +import vta.test._ + +/** VTA. + * + * This file contains all the configurations supported by VTA. + * These configurations are built in a mix/match form based on core + * and shell configurations. + */ + +class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig) +class DefaultF1Config extends Config(new CoreConfig ++ new F1Config) + +object DefaultPynqConfig extends App { + implicit val p: Parameters = new DefaultPynqConfig + chisel3.Driver.execute(args, () => new XilinxShell) +} + +object DefaultF1Config extends App { + implicit val p: Parameters = new DefaultF1Config + chisel3.Driver.execute(args, () => new XilinxShell) +} + +object TestDefaultF1Config extends App { + implicit val p: Parameters = new DefaultF1Config + chisel3.Driver.execute(args, () => new Test) +} diff --git a/vta/hardware/dpi/tsim_device.cc b/vta/hardware/dpi/tsim_device.cc index 08954179f1d2..0b315e4cb541 100644 --- a/vta/hardware/dpi/tsim_device.cc +++ b/vta/hardware/dpi/tsim_device.cc @@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle, _mem_dpi = mem_dpi; } + +// Override Verilator finish definition +// VL_USER_FINISH needs to be defined when compiling Verilator code +void vl_finish(const char* filename, int linenum, const char* hier) { + Verilated::gotFinish(true); + VL_PRINTF("[TSIM] exiting simulation\n"); +} + int VTADPISim(uint64_t max_cycles) { uint64_t trace_count = 0; + Verilated::flushCall(); + Verilated::gotFinish(false); #if VM_TRACE uint64_t start = 0; diff --git a/vta/include/vta/driver.h b/vta/include/vta/driver.h index d583051dc194..eca9e4da9799 100644 --- a/vta/include/vta/driver.h +++ b/vta/include/vta/driver.h @@ -53,7 +53,11 @@ extern "C" { typedef void * VTADeviceHandle; /*! \brief physical address */ +#ifdef USE_TSIM +typedef uint64_t vta_phy_addr_t; +#else typedef uint32_t vta_phy_addr_t; +#endif /*! * \brief Allocate a device resource handle @@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle); * * \return 0 if running is successful, 1 if timeout. */ +#ifdef USE_TSIM +int VTADeviceRun(VTADeviceHandle device, + vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles); +#else int VTADeviceRun(VTADeviceHandle device, vta_phy_addr_t insn_phy_addr, uint32_t insn_count, uint32_t wait_cycles); +#endif /*! * \brief Allocates physically contiguous region in memory (limited by MAX_XFER). diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index d5400d868ae4..4c2200d04727 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -239,7 +239,7 @@ def target_host(self): """The target host""" if self.TARGET == "pynq": return "llvm -target=armv7-none-linux-gnueabihf" - if self.TARGET == "sim": + if self.TARGET == "sim" or self.TARGET == "tsim": return "llvm" raise ValueError("Unknown target %s" % self.TARGET) diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index a1e15ba69880..858e1157d8b2 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -17,6 +17,8 @@ """Utilities to start simulator.""" import ctypes import json +import sys +import os import tvm from ..libinfo import find_libvta @@ -55,5 +57,22 @@ def stats(): x = tvm.get_global_func("vta.simulator.profiler_status")() return json.loads(x) +def tsim_init(hw_lib): + """Init hardware shared library for TSIM + + Parameters + ------------ + hw_lib : str + Name of hardware shared library + """ + cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + vta_build_path = os.path.join(cur_path, "..", "..", "..", "build") + if not hw_lib.endswith(("dylib", "so")): + hw_lib += ".dylib" if sys.platform == "darwin" else ".so" + lib = os.path.join(vta_build_path, hw_lib) + f = tvm.get_global_func("tvm.vta.tsim.init") + m = tvm.module.load(lib, "vta-tsim") + f(m) + LIBS = _load_lib() diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py index 48dd08588962..06c700cd7119 100644 --- a/vta/python/vta/testing/util.py +++ b/vta/python/vta/testing/util.py @@ -31,7 +31,7 @@ def run(run_func): """ env = get_env() - if env.TARGET == "sim": + if env.TARGET in ["sim", "tsim"]: # Talk to local RPC if necessary to debug RPC server. # Compile vta on your host with make at the root. @@ -48,7 +48,8 @@ def run(run_func): # Make sure simulation library exists # If this fails, build vta on host (make) # with TARGET="sim" in the json.config file. - assert simulator.enabled() + if env.TARGET == "sim": + assert simulator.enabled() run_func(env, rpc.LocalSession()) elif env.TARGET == "pynq": diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 79a407fe521e..06b34743955f 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -56,7 +56,7 @@ struct DataBuffer { return data_; } /*! \return Physical address of the data. */ - uint32_t phy_addr() const { + vta_phy_addr_t phy_addr() const { return phy_addr_; } /*! @@ -113,7 +113,7 @@ struct DataBuffer { /*! \brief The internal data. */ void* data_; /*! \brief The physical address of the buffer, excluding header. */ - uint32_t phy_addr_; + vta_phy_addr_t phy_addr_; }; /*! @@ -302,7 +302,7 @@ class BaseQueue { return dram_buffer_; } /*! \return Physical address of DRAM. */ - uint32_t dram_phy_addr() const { + vta_phy_addr_t dram_phy_addr() const { return dram_phy_addr_; } /*! \return Whether there is pending information. */ @@ -367,7 +367,7 @@ class BaseQueue { // The buffer in DRAM char* dram_buffer_{nullptr}; // Physics address of the buffer - uint32_t dram_phy_addr_; + vta_phy_addr_t dram_phy_addr_; }; /*! @@ -424,7 +424,11 @@ class UopQueue : public BaseQueue { CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_)); insn->memory_type = VTA_MEM_ID_UOP; insn->sram_base = sram_begin_; +#ifdef USE_TSIM + insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes; +#else insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_; +#endif insn->y_size = 1; insn->x_size = (dram_end_ - dram_begin_); insn->x_stride = (dram_end_ - dram_begin_); @@ -958,7 +962,11 @@ class CommandQueue { insn->memory_type = dst_memory_type; insn->sram_base = dst_sram_index; DataBuffer* src = DataBuffer::FromHandle(src_dram_addr); +#ifdef USE_TSIM + insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type); +#else insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset; +#endif insn->y_size = y_size; insn->x_size = x_size; insn->x_stride = x_stride; @@ -981,7 +989,11 @@ class CommandQueue { insn->memory_type = src_memory_type; insn->sram_base = src_sram_index; DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr); +#ifdef USE_TSIM + insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type); +#else insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset; +#endif insn->y_size = y_size; insn->x_size = x_size; insn->x_stride = x_stride; @@ -1046,11 +1058,24 @@ class CommandQueue { // Make sure that we don't exceed contiguous physical memory limits CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); +#ifdef USE_TSIM int timeout = VTADeviceRun( device_, insn_queue_.dram_phy_addr(), + uop_queue_.dram_phy_addr(), + inp_phy_addr_, + wgt_phy_addr_, + acc_phy_addr_, + out_phy_addr_, insn_queue_.count(), wait_cycles); +#else + int timeout = VTADeviceRun( + device_, + insn_queue_.dram_phy_addr(), + insn_queue_.count(), + wait_cycles); +#endif CHECK_EQ(timeout, 0); // Reset buffers uop_queue_.Reset(); @@ -1125,6 +1150,18 @@ class CommandQueue { ThreadLocal().reset(); } +#ifdef USE_TSIM + void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) { + switch (type) { + case VTA_MEM_ID_INP: inp_phy_addr_ = addr; + case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr; + case VTA_MEM_ID_ACC: acc_phy_addr_ = addr; + case VTA_MEM_ID_OUT: out_phy_addr_ = addr; + default: break; + } + } +#endif + private: // Push GEMM uop to the command buffer void PushGEMMOp(UopKernel* kernel) { @@ -1229,6 +1266,16 @@ class CommandQueue { InsnQueue insn_queue_; // Device handle VTADeviceHandle device_{nullptr}; +#ifdef USE_TSIM + // Input phy addr + vta_phy_addr_t inp_phy_addr_{0}; + // Weight phy addr + vta_phy_addr_t wgt_phy_addr_{0}; + // Accumulator phy addr + vta_phy_addr_t acc_phy_addr_{0}; + // Output phy addr + vta_phy_addr_t out_phy_addr_{0}; +#endif }; } // namespace vta @@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd, uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type) { +#ifdef USE_TSIM + vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr); + static_cast(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr()); +#endif static_cast(cmd)-> LoadBuffer2D(src_dram_addr, src_elem_offset, x_size, y_size, x_stride, @@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t x_size, uint32_t y_size, uint32_t x_stride) { +#ifdef USE_TSIM + vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr); + static_cast(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr()); +#endif static_cast(cmd)-> StoreBuffer2D(src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, diff --git a/vta/src/tsim/tsim_driver.cc b/vta/src/tsim/tsim_driver.cc new file mode 100644 index 000000000000..e0ceb9028503 --- /dev/null +++ b/vta/src/tsim/tsim_driver.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +namespace vta { +namespace tsim { + +using vta::dpi::DPIModuleNode; +using tvm::runtime::Module; + +class DPILoader { + public: + void Init(Module module) { + mod_ = module; + } + + DPIModuleNode* Get() { + return static_cast(mod_.operator->()); + } + + static DPILoader* Global() { + static DPILoader inst; + return &inst; + } + + Module mod_; +}; + +class Device { + public: + Device() { + dpi_ = DPILoader::Global(); + } + + int Run(vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + this->Init(); + this->Launch(insn_phy_addr, + uop_phy_addr, + inp_phy_addr, + wgt_phy_addr, + acc_phy_addr, + out_phy_addr, + insn_count, + wait_cycles); + this->WaitForCompletion(wait_cycles); + dev_->Finish(); + return 0; + } + + private: + void Init() { + dev_ = dpi_->Get(); + } + + void Launch(vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + // launch simulation thread + dev_->Launch(wait_cycles); + dev_->WriteReg(0x10, insn_count); + dev_->WriteReg(0x14, insn_phy_addr); + dev_->WriteReg(0x18, insn_phy_addr >> 32); + dev_->WriteReg(0x1c, 0); + dev_->WriteReg(0x20, uop_phy_addr >> 32); + dev_->WriteReg(0x24, 0); + dev_->WriteReg(0x28, inp_phy_addr >> 32); + dev_->WriteReg(0x2c, 0); + dev_->WriteReg(0x30, wgt_phy_addr >> 32); + dev_->WriteReg(0x34, 0); + dev_->WriteReg(0x38, acc_phy_addr >> 32); + dev_->WriteReg(0x3c, 0); + dev_->WriteReg(0x40, out_phy_addr >> 32); + // start + dev_->WriteReg(0x00, 0x1); + } + + void WaitForCompletion(uint32_t wait_cycles) { + uint32_t i, val; + for (i = 0; i < wait_cycles; i++) { + val = dev_->ReadReg(0x00); + val &= 0x2; + if (val == 0x2) break; // finish + } + } + + DPILoader* dpi_; + DPIModuleNode* dev_; +}; + +using tvm::runtime::TVMRetValue; +using tvm::runtime::TVMArgs; + +TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Module m = args[0]; + DPILoader::Global()->Init(m); + }); + +} // namespace tsim +} // namespace vta + +void* VTAMemAlloc(size_t size, int cached) { + void *p = malloc(size); + return p; +} + +void VTAMemFree(void* buf) { + free(buf); +} + +vta_phy_addr_t VTAMemGetPhyAddr(void* buf) { + return reinterpret_cast(reinterpret_cast(buf)); +} + +void VTAFlushCache(vta_phy_addr_t buf, int size) { +} + +void VTAInvalidateCache(vta_phy_addr_t buf, int size) { +} + +VTADeviceHandle VTADeviceAlloc() { + return new vta::tsim::Device(); +} + +void VTADeviceFree(VTADeviceHandle handle) { + delete static_cast(handle); +} + +int VTADeviceRun(VTADeviceHandle handle, + vta_phy_addr_t insn_phy_addr, + vta_phy_addr_t uop_phy_addr, + vta_phy_addr_t inp_phy_addr, + vta_phy_addr_t wgt_phy_addr, + vta_phy_addr_t acc_phy_addr, + vta_phy_addr_t out_phy_addr, + uint32_t insn_count, + uint32_t wait_cycles) { + return static_cast(handle)->Run( + insn_phy_addr, + uop_phy_addr, + inp_phy_addr, + wgt_phy_addr, + acc_phy_addr, + out_phy_addr, + insn_count, + wait_cycles); +} diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index 58835bbe3eab..2cedceae4e7d 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -68,6 +68,10 @@ def _run(env, remote): y_np = x_np.astype(y.dtype) x_nd = tvm.nd.array(x_np, ctx) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(x_nd, y_nd) np.testing.assert_equal(y_np, y_nd.asnumpy()) @@ -126,6 +130,10 @@ def _run(env, remote): :] = x_np x_nd = tvm.nd.array(x_np, ctx) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(x_nd, y_nd) np.testing.assert_equal(y_np, y_nd.asnumpy()) @@ -197,6 +205,9 @@ def verify(s): y_np = np.right_shift(y_np, 8) y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype) + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + if env.TARGET == "sim": simulator.clear_stats() f(x_nd, w_nd, y_nd) @@ -351,6 +362,10 @@ def check_alu(tvm_op, np_op=None, use_imm=False): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + if use_imm: f(a_nd, res_nd) else: @@ -420,6 +435,10 @@ def _run(env, remote): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(a_nd, res_nd) np.testing.assert_equal(res_np, res_nd.asnumpy()) @@ -479,6 +498,10 @@ def _run(env, remote): a_nd = tvm.nd.array(a_np, ctx) res_nd = tvm.nd.array( np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) + + if env.TARGET == "tsim": + simulator.tsim_init("libvta_hw") + f(a_nd, res_nd) np.testing.assert_equal(res_np, res_nd.asnumpy()) @@ -503,11 +526,12 @@ def _run(env, remote): print("Load/store test") test_save_load_out() print("Padded load test") - #test_padded_load() + test_padded_load() print("GEMM test") test_gemm() - test_alu() print("ALU test") + test_alu() + print("Relu test") test_relu() print("Shift and scale") test_shift_and_scale() From 9851c8398066e4881ccd3b276ce231506035e180 Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Wed, 5 Jun 2019 18:19:13 +0100 Subject: [PATCH 081/176] Add support for overloading comparison operations in relay (#2910) (#3168) --- python/tvm/relay/expr.py | 32 +++++++++++++++++++++++++++ tests/python/relay/test_cmp_op.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/python/relay/test_cmp_op.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 98b4a83e09de..8e7f95c4dc26 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -70,6 +70,38 @@ def astype(self, dtype): def __neg__(self): return _op_make.negative(self) + def __lt__(self, other): + if isinstance(other, Expr): + return _op_make.less(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __gt__(self, other): + if isinstance(other, Expr): + return _op_make.greater(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __ge__(self, other): + if isinstance(other, Expr): + return _op_make.greater_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __le__(self, other): + if isinstance(other, Expr): + return _op_make.less_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + def __add__(self, other): if isinstance(other, Expr): return _op_make.add(self, other) diff --git a/tests/python/relay/test_cmp_op.py b/tests/python/relay/test_cmp_op.py new file mode 100644 index 000000000000..d096eec598b7 --- /dev/null +++ b/tests/python/relay/test_cmp_op.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm import relay +a = relay.Var("a") +b = relay.expr.const (1.0, dtype='float32') + +c = a < b +d = relay.less (a, b) +assert (c.astext() == d.astext()) + +c = a > b +d = relay.greater (a, b) +assert (c.astext() == d.astext()) + +c = (a >= b) +d = relay.greater_equal(a, b) +assert (c.astext() == d.astext()) + +c = (a <= b) +d = relay.less_equal(a, b) +assert (c.astext() == d.astext()) From b130c9ceb7fa07b88710c416c62f17a9b0d8d365 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Wed, 5 Jun 2019 16:23:11 -0700 Subject: [PATCH 082/176] fast tanh (#3255) --- topi/include/topi/elemwise.h | 74 +++++++++++++++++++++++++++-- topi/tests/python/test_topi_math.py | 30 +++++++----- 2 files changed, 90 insertions(+), 14 deletions(-) diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index a9f8f630471f..b3681e17da8d 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,6 +31,7 @@ #include "tvm/tvm.h" #include "tvm/ir.h" #include "tvm/ir_pass.h" +#include "broadcast.h" namespace topi { using namespace tvm; @@ -46,7 +47,6 @@ using namespace tvm; } TOPI_DECLARE_UNARY_OP(exp); -TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sqrt); TOPI_DECLARE_UNARY_OP(log); @@ -56,6 +56,74 @@ TOPI_DECLARE_UNARY_OP(round); TOPI_DECLARE_UNARY_OP(trunc); TOPI_DECLARE_UNARY_OP(abs); +/* + * \brief Fast_tanh_float implementation from Eigen + * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26 + */ +inline Tensor fast_tanh_float(const Tensor& in, + std::string name, + std::string tag) { + // Clamp the inputs to the range [-9, 9] since anything outside + // this range is +/-1.0f in single-precision. + auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0)); + + // The monomial coefficients of the numerator polynomial (odd). + auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03); + auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04); + auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05); + auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08); + auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11); + auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13); + auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16); + + // The monomial coefficients of the denominator polynomial (even). + auto beta_0 = make_const(in->dtype, 4.89352518554385e-03); + auto beta_2 = make_const(in->dtype, 2.26843463243900e-03); + auto beta_4 = make_const(in->dtype, 1.18534705686654e-04); + auto beta_6 = make_const(in->dtype, 1.19825839466702e-06); + + return compute(x->shape, + [&](const Array& i) { + auto x2 = x(i) * x(i); + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x(i) * p; + + auto q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + return p / q; + }, + name, tag); +} + +/*! +* \brief Creates an operation that returns hyperbolic tanh of a given tensor +* +* \param x The input tensor +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is tanh +*/ +inline Tensor tanh(const Tensor& x, + std::string name = "T_tanh", + std::string tag = kElementWise) { + if (x->dtype == Float(32)) { + // invoke fast_tanh_float implementation + return fast_tanh_float(x, name, tag); + } else { + // fallback to default implementation + return compute(x->shape, [&](const Array& i) { + return ::tvm::tanh(x(i)); + }, name, tag); + } +} + /*! * \brief Creates an operation that returns identity of a given tensor * diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index d6df450628d2..a276f12b27f0 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -29,13 +29,21 @@ def test_util(): def test_ewise(): - m = tvm.var('m') - l = tvm.var('l') - A = tvm.placeholder((m, l), name='A') + def test_apply( + func, + name, + f_numpy, + low, + high, + shape=(20, 3), + dtype=tvm.float32, + check_round=False, + skip_name_check=False, + ): + m = tvm.var("m") + l = tvm.var("l") + A = tvm.placeholder((m, l), dtype=dtype, name="A") - shape = (20, 3) - - def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False): B = func(A) assert tuple(B.shape) == tuple(A.shape) if not skip_name_check: @@ -63,7 +71,6 @@ def check_device(device): for device in get_all_backend(): check_device(device) - test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True) @@ -71,11 +78,12 @@ def check_device(device): test_apply(topi.abs, "fabs", np.abs, -100, 100) test_apply(topi.round, "round", np.round, -100, 100, check_round=True) test_apply(topi.exp, "exp", np.exp, -1, 1) - test_apply(topi.tanh, "tanh", np.tanh, -10, 10) - test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1) + test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128)) + test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128), dtype="float64") + test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1) test_apply(topi.log, "log", np.log, 0, 100) test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) - test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True) + test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True) def test_cast(): @@ -93,7 +101,7 @@ def verify(from_dtype, to_dtype, low=-100, high=100): b_np = a_np.astype(to_dtype) for device in get_all_backend(): - ctx = tvm.context(device, 0) + ctx = tvm.context(device, 0) if not ctx.exist: print("Skip because %s is not enabled" % device) continue From d1325be02e1a4fd7aef8e16ed70ed37e84ab5e8e Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 5 Jun 2019 16:27:16 -0700 Subject: [PATCH 083/176] Ghost nodes in NNVM graph (#3290) --- nnvm/include/nnvm/op_attr_types.h | 11 +++++++++++ nnvm/src/core/graph.cc | 3 +++ 2 files changed, 14 insertions(+) diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 976ad929f496..ad328c30312a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -136,6 +136,17 @@ using FInferType = FInferNodeEntryAttr; */ using TIsBackward = bool; +/*! + * \brief Whether this op is a ghost node. + * If TIsGhost is true: + * - The node with this op will not be visible in the indexed graph. + * + * \note Register under "TIsGhost" + * This enables shape/type inference for backward nodes when + * fusion is present. + */ +using TIsGhost = bool; + /*! * \brief Get possible inplace options. * This function enables optimization to reuse memory of inputs in output. diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 92ff98618ec8..29149f48fdb0 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] (const NodePtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); CHECK(n); @@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { inputs_rptr.push_back(input_entries_.size()); // control deps for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; auto it = node2index_.find(nptr.get()); CHECK(it != node2index_.end() && it->first == nptr.get()); control_deps_.push_back(it->second); From ceac70b64c788c8886a5d321713f32a7342e5f98 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 5 Jun 2019 21:42:20 -0700 Subject: [PATCH 084/176] Improve x86 roi align (#3296) * Improve roi_align performance for x86 * Change test --- topi/python/topi/x86/__init__.py | 1 + topi/python/topi/x86/roi_align.py | 217 ++++++++++++++++++++++++++ topi/tests/python/test_topi_vision.py | 1 + 3 files changed, 219 insertions(+) create mode 100644 topi/python/topi/x86/roi_align.py diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index cce816d43ba1..efc1bc512285 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -13,3 +13,4 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack from .batch_matmul import schedule_batch_matmul +from .roi_align import roi_align_nchw diff --git a/topi/python/topi/x86/roi_align.py b/topi/python/topi/x86/roi_align.py new file mode 100644 index 000000000000..a8ad387a242f --- /dev/null +++ b/topi/python/topi/x86/roi_align.py @@ -0,0 +1,217 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements +"""Non-maximum suppression operator for intel cpu""" +import tvm + +from tvm import hybrid +from ..vision.rcnn import roi_align_nchw + + +@hybrid.script +def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): + """Hybrid routing fo ROI align operator in NCHW layout. + + Parameters + ---------- + data : tvm.Tensor or numpy NDArray + 4-D with shape [batch, channel, height, width] + + rois : tvm.Tensor or numpy NDArray + 2-D with shape [num_roi, 5]. The last dimension should be in format of + [batch_index, w_start, h_start, w_end, h_end] + + pooled_size : tvm ConsExpr + [out_height, out_width] + + spatial_scale : tvm.const + Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal + of total stride in convolutional layers, which should be in range (0.0, 1.0] + + sample_ratio : tvm.const + Sampling ratio of ROI align, using adaptive size by default. + + Returns + ------- + output : tvm.Tensor or numpy NDArray + 4-D with shape [num_roi, channel, pooled_size, pooled_size] + """ + channels = data.shape[1] + height = data.shape[2] + width = data.shape[3] + num_rois = rois.shape[0] + pooled_size_h = pooled_size[0] + pooled_size_w = pooled_size[1] + output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype) + max_num_pc_index = height * width * pooled_size_h * pooled_size_w + w_pc = allocate((num_rois, max_num_pc_index, 4), data.dtype) + pos_pc = allocate((num_rois, max_num_pc_index, 4), "int32") + + for n in parallel(num_rois): + roi_batch_index = int32(rois[n, 0]) + roi_start_w = rois[n, 1] * spatial_scale + roi_start_h = rois[n, 2] * spatial_scale + roi_end_w = rois[n, 3] * spatial_scale + roi_end_h = rois[n, 4] * spatial_scale + + roi_h = max(roi_end_h - roi_start_h, 1.0) + roi_w = max(roi_end_w - roi_start_w, 1.0) + + bin_h = roi_h / pooled_size_h + bin_w = roi_w / pooled_size_w + + roi_bin_grid_h = sample_ratio + roi_bin_grid_w = roi_bin_grid_h + div_h = roi_h / pooled_size_h + div_w = roi_w / pooled_size_w + rounded_div_h = int32(div_h) * 1.0 + rounded_div_w = int32(div_w) * 1.0 + if sample_ratio <= 0: + # Cannot use ceil function since hybrid script + # doesn't support Call as indexing + roi_bin_grid_h = int32(div_h) + roi_bin_grid_w = int32(div_w) + if rounded_div_h < div_h: + roi_bin_grid_h += 1 + if rounded_div_w < div_w: + roi_bin_grid_w += 1 + + count = roi_bin_grid_h * roi_bin_grid_w + + # Pre-calculate indices and weights shared by all channels. + # This is the key point of optimization. + pre_calc_index = 0 + iy_upper = roi_bin_grid_h + ix_upper = roi_bin_grid_w + for ph in range(pooled_size_h): + for pw in range(pooled_size_w): + for iy in range(iy_upper): + yy = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h + for ix in range(ix_upper): + xx = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + x = xx + y = yy + if y < -1.0 or y > height or x < -1.0 or x > width: + for i in range(4): + w_pc[n, pre_calc_index, i] = 0.0 + pos_pc[n, pre_calc_index, i] = 0 + else: + if y < 0.0: + y = 0.0 + if x < 0.0: + x = 0.0 + + y_low = int32(y) + x_low = int32(x) + x_high = x_low + 1 + y_high = y_low + 1 + + if y_low >= height - 1: + y_high = height - 1 + y_low = y_high + y = float32(y_low) + + if x_low >= width - 1: + x_high = width - 1 + x_low = x_high + x = float32(x_low) + + ly = y - y_low + lx = x - x_low + hy = 1.0 - ly + hx = 1.0 - lx + w1 = hy * hx + w2 = hy * lx + w3 = ly * hx + w4 = ly * lx + + pos_pc[n, pre_calc_index, 0] = x_low + pos_pc[n, pre_calc_index, 1] = x_high + pos_pc[n, pre_calc_index, 2] = y_low + pos_pc[n, pre_calc_index, 3] = y_high + w_pc[n, pre_calc_index, 0] = w1 + w_pc[n, pre_calc_index, 1] = w2 + w_pc[n, pre_calc_index, 2] = w3 + w_pc[n, pre_calc_index, 3] = w4 + + pre_calc_index += 1 + + for c in range(channels): + pre_calc_index = 0 + for ph in range(pooled_size_h): + for pw in range(pooled_size_w): + output_val = 0.0 + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + output_val += w_pc[n, pre_calc_index, 0] \ + * data[roi_batch_index, c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 0]] \ + + w_pc[n, pre_calc_index, 1] \ + * data[roi_batch_index, c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 1]] \ + + w_pc[n, pre_calc_index, 2] \ + * data[roi_batch_index, c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 0]] \ + + w_pc[n, pre_calc_index, 3] \ + * data[roi_batch_index, c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 1]] + pre_calc_index += 1 + + output_val /= count + output[n, c, ph, pw] = output_val + + return output + + +@roi_align_nchw.register("cpu") +def roi_align_nchw_cpu(data, rois, pooled_size, spatial_scale, sample_ratio=-1): + """ROI align operator in NCHW layout. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, channel, height, width] + + rois : tvm.Tensor + 2-D with shape [num_roi, 5]. The last dimension should be in format of + [batch_index, w_start, h_start, w_end, h_end] + + pooled_size : int or list/tuple of two ints + output size, or [out_height, out_width] + + spatial_scale : float + Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal + of total stride in convolutional layers, which should be in range (0.0, 1.0] + + sample_ratio : int + Optional sampling ratio of ROI align, using adaptive size by default. + + Returns + ------- + output : tvm.Tensor + 4-D with shape [num_roi, channel, pooled_size, pooled_size] + """ + if not isinstance(pooled_size, (tuple, list)): + pooled_size = (pooled_size, pooled_size) + pooled_size = tvm.convert(pooled_size) + spatial_scale = tvm.const(spatial_scale, "float32") + sample_ratio = tvm.const(sample_ratio, "int32") + return roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 483f3a641c70..54c80c6e8c30 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -282,6 +282,7 @@ def check_device(device): def test_roi_align(): verify_roi_align(1, 16, 32, 64, 7, 1.0, -1) verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) + verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2) def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): From 265a71a7cd9279a5cda4d267eb104d7b75a55cd9 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Wed, 5 Jun 2019 22:03:12 -0700 Subject: [PATCH 085/176] [VTA] [APPS] [TSIM] small naming fix (#3293) * make off lowercase * update README --- vta/apps/tsim_example/README.md | 5 ++--- vta/apps/tsim_example/cmake/modules/hw.cmake | 8 ++++---- vta/apps/tsim_example/config/config.json | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index dc06a92f2b0e..e19d6bbe7ef0 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -49,8 +49,7 @@ sudo apt install verilator sbt ## Setup in TVM 1. Install `verilator` and `sbt` as described above -2. Set the VTA TARGET to `tsim` on `/vta/config/vta_config.json` -3. Build tvm +2. Build tvm ## How to run VTA TSIM examples @@ -60,7 +59,7 @@ These examples are located at `/vta/apps/tsim_example`. * Instructions * Open `/vta/apps/tsim_example/python/tsim/config.json` * Change `TARGET` from `verilog` to `chisel`, depending on what language backend you would like to test - * Go to `tvm/vta/apps/tsim` + * Go to `tvm/vta/apps/tsim_example` * Run `make` * Some pointers diff --git a/vta/apps/tsim_example/cmake/modules/hw.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake index e016ea03b6fa..019be129f243 100644 --- a/vta/apps/tsim_example/cmake/modules/hw.cmake +++ b/vta/apps/tsim_example/cmake/modules/hw.cmake @@ -87,7 +87,7 @@ else() if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog") # Check if tracing can be enabled - if (NOT TSIM_USE_TRACE STREQUAL "OFF") + if (NOT TSIM_USE_TRACE STREQUAL "off") message(STATUS "[TSIM_HW] Verilog enable tracing") else() message(STATUS "[TSIM_HW] Verilator disable tracing") @@ -101,7 +101,7 @@ else() list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_BUILD_DIR}) list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC}) - if (NOT TSIM_USE_TRACE STREQUAL "OFF") + if (NOT TSIM_USE_TRACE STREQUAL "off") list(APPEND VERILATOR_OPT --trace) endif() @@ -116,7 +116,7 @@ else() set(VERILATOR_INC_DIR /usr/local/share/verilator/include) set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp) - if (NOT TSIM_USE_TRACE STREQUAL "OFF") + if (NOT TSIM_USE_TRACE STREQUAL "off") list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp) endif() @@ -125,7 +125,7 @@ else() add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) - if (NOT TSIM_USE_TRACE STREQUAL "OFF") + if (NOT TSIM_USE_TRACE STREQUAL "off") list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd) else() list(APPEND VERILATOR_DEF VM_TRACE=0) diff --git a/vta/apps/tsim_example/config/config.json b/vta/apps/tsim_example/config/config.json index 5f9ee69904fd..887eaac67d74 100644 --- a/vta/apps/tsim_example/config/config.json +++ b/vta/apps/tsim_example/config/config.json @@ -2,6 +2,6 @@ "TARGET" : "verilog", "TOP_NAME" : "TestAccel", "BUILD_NAME" : "build", - "USE_TRACE" : "OFF", + "USE_TRACE" : "off", "TRACE_NAME" : "trace" } From 21ea7e6984874fc254d6430b2f3f37010d45d295 Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Thu, 6 Jun 2019 21:00:19 +0300 Subject: [PATCH 086/176] [Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993) --- python/tvm/relay/frontend/tensorflow.py | 190 +++++++++--------- .../frontend/tensorflow/test_forward.py | 78 ++++--- topi/python/topi/util.py | 12 +- 3 files changed, 140 insertions(+), 140 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 307fb20693f4..f709a63e79e8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -63,7 +63,7 @@ def _get_relay_op(op_name): return op class AttrCvt(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -222,17 +222,37 @@ def _dim_check(attrs): return False return _dim_check, "Only 2d kernel supported." -def _infer_channels(inputs, params, transpose=False): - """A hack for getting 'channles' or 'units' since tensorflow don't provide +def _infer_channels(node, params, transpose=False): + """A hack for getting 'channels' or 'units' since tensorflow don't provide these attributes. We check the shape of weights provided to get the number. """ - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - channels = out_shapes[0][0] if not transpose else out_shapes[0][1] + out_shape = _infer_shape(node, params) + channels = out_shape[0] if not transpose else out_shape[1] return channels +def _infer_out_shapes(inputs, params): + """A method to get the output shape of intermediate nodes in the relay graph.""" + return [_infer_shape(inputs, params)] + +def _infer_shape(node, params=None): + """A method to get the output shape of an intermediate node in the relay graph.""" + out_type = ir_pass.infer_type(node) + return get_const_tuple(out_type.checked_type.shape) + +def _get_param(params, input_node): + return params.pop(input_node.name_hint).asnumpy() + +def _get_num_param(params, input_node): + return _get_param(params, input_node)[0] + +def _get_list_param(params, input_node): + return _get_param(params, input_node).tolist() + +def _get_tuple_param(params, input_node): + return tuple(_get_param(params, input_node)) + def _rsqrt(): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) return AttrCvt(op_name="power")(inputs, attr) return _impl @@ -243,16 +263,15 @@ def _impl(inputs, attr, params): try: # In Tensorflow, `axis` argument is a Tensor, not attribute. We # support the case where it inputs from a scalar constant. - axis_input_name = inputs[1].name_hint - axis_input_vlaue = [params[axis_input_name].asnumpy()[0]] + axis_input_value = [_get_num_param(params, inputs[1])] except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) - return func(inputs[0], axis=axis_input_vlaue, keepdims=False) + return func(inputs[0], axis=axis_input_value, keepdims=False) return _impl def _elemwise(name): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return _get_relay_op(name)(*inputs) return _impl @@ -472,7 +491,7 @@ def _impl(inputs, attr, params): def _expand_dims(): def _impl(inputs, attr, params): dim_input = inputs.pop(1) - axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0] + axis = _get_num_param(params, dim_input) return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr) return _impl @@ -527,21 +546,19 @@ def _impl(inputs, attr, params): def _concatV2(): def _impl(inputs, attr, params): pop_node = inputs.pop(len(inputs)-1) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['T', 'N', 'Tidx'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _concat(): def _impl(inputs, attr, params): pop_node = inputs.pop(0) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['N'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _pack(): @@ -565,8 +582,8 @@ def _impl(inputs, attr, params): def _slice(): def _impl(inputs, attr, params): - begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() - size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + size = _get_list_param(params, inputs[2]) data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) end = size @@ -581,24 +598,18 @@ def _impl(inputs, attr, params): def _reshape(): def _impl(inputs, attr, params): + pop_node = inputs.pop(1) try: - pop_node = inputs[1] - shape_arg = params.pop(pop_node.name_hint) - inputs.pop(1) - - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(shape_arg.asnumpy())}, - ignores=['Tshape'])(inputs, attr) + shape_arg = _get_tuple_param(params, pop_node) except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - params_new = _infer_value(inputs[1], params) - inputs.pop(1) - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())}, - ignores=['Tshape'])(inputs, attr) + params_new = _infer_value(pop_node, params) + shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + return AttrCvt( + op_name="reshape", + extras={'newshape': shape_arg}, + ignores=['Tshape'])(inputs, attr) return _impl @@ -737,9 +748,10 @@ def _impl(inputs, attr, params): if -1 in output_shape: output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() - fill_arg = params.pop(inputs.pop(1).name_hint) - return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name), - output_shape, attr['T'].name) + fill_arg = _get_num_param(params, inputs.pop(1)) + dtype = attr['T'].name + return _op.full(tvm.relay.const(fill_arg, dtype), + output_shape, dtype) return _impl def _lrn(): @@ -757,9 +769,7 @@ def _impl(inputs, attr, params): def _sum(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy() - # convert to tuple for preventing invalid parameter format error - axis = tuple(axis) + axis = _get_tuple_param(params, inputs[1]) return AttrCvt( op_name='sum', extras={'axis': axis}, @@ -786,25 +796,17 @@ def _impl(inputs, attr, params): def _gather(): "GatherV2, Gather" def _impl(inputs, attr, params): - - axis = 0 if len(inputs) > 2: - axis = params[inputs.pop(2).name_hint].asnumpy()[0] - new_input = [] - new_input.append(inputs.pop(0)) - new_input.append(inputs.pop(0)) + axis = _get_num_param(params, inputs.pop(2)) + else: + axis = 0 + new_input = inputs[0:2] return AttrCvt(op_name="take", extras={'axis': tvm.const(axis, 'int32')}, - ignores=['Tindices', 'Tparams', 'validate_indices', \ + ignores=['Tindices', 'Tparams', 'validate_indices', 'Taxis', '_class'])(new_input, attr) return _impl -def _infer_out_shapes(inputs, params): - """A method to get the output shape of an intermediate node in the relay graph.""" - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - return out_shapes - def _stridedSlice(): def _impl(inputs, attr, params): """Strided Slice. @@ -812,9 +814,9 @@ def _impl(inputs, attr, params): Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/core/util/strided_slice_op.cc#L147-L368 """ - begin = params.pop(inputs[1].name_hint).asnumpy().tolist() - end = params.pop(inputs[2].name_hint).asnumpy().tolist() - stride = params.pop(inputs[3].name_hint).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + end = _get_list_param(params, inputs[2]) + stride = _get_list_param(params, inputs[3]) begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) @@ -889,7 +891,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_out_shapes(out, params)[0] + out_shape = _infer_shape(out, params) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -910,19 +912,14 @@ def _transform_mask(stride_dim, ellipsis_mask): def _pad(name): def _impl(inputs, attr, params): - padlist_key = inputs[1].name_hint - if padlist_key in params: - padlist = params.pop(padlist_key).asnumpy() - else: - raise tvm.error.OpAttributeRequired( - 'Attribute {} not found in operator Pad.'.format(padlist_key)) - paddings = tuple([tuple(l) for l in padlist]) + padlist = _get_param(params, inputs[1]) + paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings attr['pad_value'] = 0 new_inputs = [inputs[0]] if name == 'PadV2': - constant_values = params.pop(inputs[2].name_hint).asnumpy() - attr['pad_value'] = constant_values[0] + constant_values = _get_num_param(params, inputs[2]) + attr['pad_value'] = constant_values return AttrCvt( op_name='pad', ignores=['Tpaddings'],)(new_inputs, attr) @@ -932,10 +929,9 @@ def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, # otherwise its value is get from params - param_name = _get_name_hint(inputs[1]) - if param_name in params: - axes = tuple(params.get(param_name).asnumpy()) - else: + try: + axes = _get_list_param(params, inputs[1]) + except (IndexError, KeyError): axes = None return _op.transpose(inputs[0], axes=axes) return _impl @@ -947,7 +943,7 @@ def _impl(inputs, attr, params): def _reverse_v2(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy()[0] + axis = _get_num_param(params, inputs[1]) return AttrCvt( op_name="reverse", ignores=['Tidx'], @@ -968,9 +964,9 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): - start = params.pop(inputs[0].name_hint).asnumpy()[0] - limit = params.pop(inputs[1].name_hint).asnumpy()[0] - delta = params.pop(inputs[2].name_hint).asnumpy()[0] + start = _get_num_param(params, inputs[0]) + limit = _get_num_param(params, inputs[1]) + delta = _get_num_param(params, inputs[2]) name = attr["_node_name"] params[name] = tvm.nd.array([start, limit, delta]) @@ -981,25 +977,27 @@ def _impl(inputs, attr, params): def _elu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.0, attr['T'].name) - return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.0, dtype) + return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) return _impl def _selu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name) - gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name) - return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) + gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) + return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) return _impl def _mean(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint) + axis = _get_tuple_param(params, inputs[1]) return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], transforms={'keep_dims': 'keepdims'}, - extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr) + extras={'axis': axis})([inputs[0]], attr) return _impl def _broadcast(name): @@ -1025,8 +1023,7 @@ def _impl(inputs, attr, params): if has_size_vector: input_node_index = 0 input_axis_index = 2 - size_splits_input_name = _get_name_hint(inputs[1]) - size_splits = params[size_splits_input_name].asnumpy() + size_splits = _get_param(params, inputs[1]) section_beginnings = np.cumsum(size_splits)[:-1] indices_or_sections = tuple(section_beginnings) else: @@ -1034,8 +1031,7 @@ def _impl(inputs, attr, params): input_axis_index = 0 indices_or_sections = attr['num_split'] input_node = inputs[input_node_index] - axis_input_name = _get_name_hint(inputs[input_axis_index]) - axis_input_value = params[axis_input_name].asnumpy()[0] + axis_input_value = _get_num_param(params, inputs[input_axis_index]) except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for split: `axis` and `num_or_size_splits` " \ @@ -1105,8 +1101,8 @@ def _space_to_batch_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + paddings = _get_list_param(params, inputs[2]) N = len(input_shape) M = len(block_shape) batch = input_shape[0] @@ -1127,7 +1123,7 @@ def _impl(inputs, attr, params): axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] + permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., @@ -1144,8 +1140,8 @@ def _batch_to_space_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - crops = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + crops = _get_list_param(params, inputs[2]) M = len(block_shape) batch = input_shape[0] # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: @@ -1170,7 +1166,7 @@ def _impl(inputs, attr, params): # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] - reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0] + reshaped_permuted_shape = _infer_shape(reshaped_permuted, params) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] @@ -1971,23 +1967,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Infer shapes even without specifying "add_shapes=True" if output_shapes == [None]: - out_shapes = [] - for node_item in self._nodes[node.name]: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]] self._output_shapes[node.name] = out_shapes if self._output_shapes[node.name] and shape and node.name in shape: assert self._output_shapes[node.name] == list(shape[node.name]) - # Infer shapes if passed explicitely + # Infer shapes if passed explicitly node_output = self._nodes[node.name] if shape and (not self._output_shapes[node.name][0] or -1 in self._output_shapes[node.name][0]): - out_shapes = [] - for node_item in node_output: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in node_output] self._output_shapes[node.name] = out_shapes out = [] diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index eebb73c95b1b..3899bc04d5c6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, layout = None if target == "cuda": layout = "NCHW" - target_host = 'llvm' - - if isinstance(input_data, list): - shape_dict = {} - dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype - else: - shape_dict = {input_node: input_data.shape} - dtype_dict = {input_node: input_data.dtype} + target_host = None + + shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names) with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build(sym, target, params=params) + graph, lib, params = relay.build(sym, target, target_host, params) ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs - for i, e in enumerate(input_node): - m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + for e, i in zip(input_node, input_data): + m.set_input(e, tvm.nd.array(i)) m.set_input(**params) # execute @@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, # get outputs assert out_names is None or num_output == len(out_names), ( "out_names: {} num_output: {}".format(out_names, num_output)) - tvm_output_list = [] - for i in range(0, num_output): - tvm_output = m.get_output(i) - tvm_output_list.append(tvm_output.asnumpy()) + tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)] return tvm_output_list def run_tf_graph(sess, input_data, input_node, output_node): @@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node): input_node = convert_to_list(input_node) output_node = convert_to_list(output_node) - tensor = [0] * len(output_node) - for i in range(len(output_node)): - tensor[i] = sess.graph.get_tensor_by_name(output_node[i]) + tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node] - input_dict = {} - for i, e in enumerate(input_node): - input_dict[e] = input_data[i] + input_dict = {e: input_data[i] for i, e in enumerate(input_node)} output_data = sess.run(tensor, input_dict) return output_data @@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node): def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False, opt_level=3): """Generic function to generate and compare tensorflow and TVM output""" + def name_without_num(name): + return name.split(':')[0] if ":" in name else name out_name = convert_to_list(out_name) - out_node = [0]*len(out_name) - for i in range(len(out_name)): - out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i] + out_node = [name_without_num(name) for name in out_name] in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) - in_node = [0]*len(in_name) - for i in range(len(in_name)): - in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i] + in_node = [name_without_num(name) for name in in_name] with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) @@ -577,6 +560,38 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +####################################################################### +# MatMul +# ------ + +def _test_matmul(i, j, k, dtype, outer=None): + """ One iteration of matmul """ + + A_shape_init = [i, j] + B_shape_init = [j, k] + + for transpose_a in [False, True]: + for transpose_b in [False, True]: + outer = outer or [] + A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init) + B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init) + + with tf.Graph().as_default(): + A = tf.placeholder(shape=A_shape, dtype=dtype, name='A') + B = tf.placeholder(shape=B_shape, dtype=dtype, name='B') + result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b) + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + +def test_forward_matmul(): + """ Matmul op test""" + _test_matmul(1, 3, 6, 'int32') + _test_matmul(5, 3, 1, 'float64') + # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support + + ####################################################################### # StridedSlice # ------------ @@ -1785,3 +1800,6 @@ def test_placeholder(): test_forward_rel_ops() test_forward_logical() test_where() + + test_forward_matmul() + # TODO missing tests: rank, range \ No newline at end of file diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index f648245c6bb7..623c81a07da8 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -151,11 +151,7 @@ def get_const_tuple(in_tuple): out_tuple : tuple of int The output. """ - out_tuple = () - for elem in in_tuple: - value = get_const_int(elem) - out_tuple = out_tuple + (value, ) - return out_tuple + return tuple(get_const_int(elem) for elem in in_tuple) def get_float_tuple(in_tuple): @@ -171,11 +167,7 @@ def get_float_tuple(in_tuple): out_tuple : tuple of float The output. """ - out_tuple = () - for elem in in_tuple: - value = get_const_float(elem) - out_tuple = out_tuple + (value, ) - return out_tuple + return tuple(get_const_float(elem) for elem in in_tuple) def simplify(expr): From 9c8ffa2802cfebd3fb51b4669464cbbb75f3e14d Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 6 Jun 2019 11:41:50 -0700 Subject: [PATCH 087/176] Fix x86 depthwise conv2d alter_op_layout (#3264) * Fix x86 depthwise conv2d alter_op_layout * Small fix * Add test case * Fix test * Assert kernel layout * Minor fix * Add get_shape function * Minor change --- .../python/relay/test_pass_alter_op_layout.py | 43 ++++++++++++++++++- topi/python/topi/arm_cpu/depthwise_conv2d.py | 6 +-- topi/python/topi/util.py | 39 +++++++++++++++++ topi/python/topi/x86/conv2d.py | 9 ++-- topi/python/topi/x86/depthwise_conv2d.py | 11 ++++- topi/tests/python/test_topi_util.py | 35 +++++++++++++++ 6 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 topi/tests/python/test_topi_util.py diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 2eea1c4ca87a..7d022ba25570 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test alter op layout pass""" +import tvm from tvm import relay from tvm.relay.op import register_alter_op_layout @@ -513,6 +514,45 @@ def expected(): assert alpha_equal(a, b), "Actual = \n" + str(a) +def test_alter_layout_depthwise_conv2d(): + """Test depthwise_conv2d operator""" + def before(): + x = relay.var("x", shape=(1, 32, 56, 56)) + w = relay.var("w", shape=(32, 1, 3, 3)) + y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32) + y = relay.Function(free_vars(y), y) + return y + + import topi + @register_alter_op_layout("nn.conv2d", level=110) + def alter_conv2d(attrs, inputs, tinfos): + with tvm.target.create("llvm"): + return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay) + + def expected(): + x = relay.var("x", shape=(1, 32, 56, 56)) + w = relay.var("w", shape=(32, 1, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NCHW8c") + w = relay.layout_transform(w, "OIHW", "OIHW1i8o") + y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), + groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o", + out_layout="NCHW8c") + y = relay.layout_transform(y, "NCHW8c", "NCHW") + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = canonicalize_ops(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + def test_alter_layout_prelu(): """Test PRelu operator""" def before(): @@ -524,7 +564,7 @@ def before(): y = relay.Function(free_vars(y), y) return y - @register_alter_op_layout("nn.conv2d", level=110) + @register_alter_op_layout("nn.conv2d", level=111) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) @@ -571,4 +611,5 @@ def expected(): test_alter_layout_concatenate() test_alter_layout_nchw_upsamping_op() test_alter_layout_strided_slice() + test_alter_layout_depthwise_conv2d() test_alter_layout_prelu() diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py index e09e355ad8aa..51088df905ca 100644 --- a/topi/python/topi/arm_cpu/depthwise_conv2d.py +++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py @@ -26,11 +26,11 @@ from ..nn.util import get_pad_tuple # register original implementation of depthwise_conv2d_nchw since we don't need to change this part -autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct', +autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct', depthwise_conv2d_nchw.fdefault) # register customized schedule for arm cpu. -@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], +@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', ['direct', 'contrib_spatial_pack']) def schedule_depthwise_conv2d_nchw_arm(cfg, outs): """Schedule depthwise conv2d @@ -151,7 +151,7 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s -@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack']) +@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack']) def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype): """TOPI compute callback for depthwise_conv2d nchw diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 623c81a07da8..d4e23be47e58 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -20,6 +20,7 @@ from numbers import Integral import tvm +from tvm.api import layout, bijective_layout from . import tag def traverse_inline(s, final_op, callback): @@ -289,3 +290,41 @@ def get_max_power2_factor(n, max_value=None): x *= 2 n /= 2 return x + + +def get_shape(src_shape, src_layout, dst_layout): + """Given a source shape, a source layout and a destination layout, infer + the destination shape. + + Parameter + --------- + src_shape : tuple of int or IntImm + Source shape + + src_layout : str or Layout + Source layout + + dst_layout : str or Layout + Destination layout + + Returns + ------- + dst_shape : tuple of int + Destination shape + """ + if src_layout == dst_layout: + return get_const_tuple(src_shape) + + if isinstance(src_layout, str): + src_layout = layout(src_layout) + if isinstance(dst_layout, str): + dst_layout = layout(dst_layout) + + assert len(src_layout) == len(dst_layout), \ + "Incompatible layout %s vs %s" % (src_layout, dst_layout) + + layout_mapping = bijective_layout(src_layout, dst_layout) + dst_indices = layout_mapping.forward_index( + tvm.convert([i for i in range(len(src_layout))])) + + return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices])) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index d9831c8347f3..08becf428c27 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -26,7 +26,7 @@ from tvm.autotvm.task import get_config from .. import generic, tag from .. import nn -from ..util import get_const_tuple +from ..util import get_const_tuple, get_shape from ..nn.conv2d import conv2d, conv2d_NCHWc, \ conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload @@ -415,11 +415,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): dtype = data.dtype out_dtype = dtype if out_dtype in ("same", "") else out_dtype - is_depthwise = groups == in_channel and groups == out_channel + + kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW") + is_depthwise = groups == kshape[0] and kshape[1] == 1 # only optimize for NCHW - if layout != 'NCHW': + if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW": return None + if groups != 1 and not is_depthwise: return None diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 6ea11f234759..ddcd8415df77 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -22,11 +22,12 @@ from tvm.autotvm.task.space import SplitEntity from tvm.autotvm.task.topi_integration import deserialize_args from .. import generic, tag +from ..generic import schedule_depthwise_conv2d_nchw from ..nn.pad import pad from ..util import get_const_tuple from ..nn.util import get_pad_tuple -from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \ - depthwise_conv2d_infer_layout +from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \ + _get_workload, depthwise_conv2d_infer_layout from .util import get_fp32_len @@ -70,6 +71,12 @@ def _fallback_schedule(cfg, wkl): cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) +autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct', + depthwise_conv2d_nchw.fdefault) +autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct', + schedule_depthwise_conv2d_nchw.fdefault) + + @autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct') def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype=None): diff --git a/topi/tests/python/test_topi_util.py b/topi/tests/python/test_topi_util.py new file mode 100644 index 000000000000..534b6993d411 --- /dev/null +++ b/topi/tests/python/test_topi_util.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for util""" + +import topi + + +def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape): + dst_shape = topi.util.get_shape(src_shape, src_layout, dst_layout) + assert dst_shape == expect_shape, \ + "Shape mismatch: expecting %s but got %s" % (expect_shape, dst_shape) + + +def test_get_shape(): + verify_get_shape((1, 3, 224, 224), "NCHW", "NCHW", (1, 3, 224, 224)) + verify_get_shape((1, 3, 224, 224), "NCHW", "NHWC", (1, 224, 224, 3)) + verify_get_shape((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32)) + verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16)) + +if __name__ == "__main__": + test_get_shape() \ No newline at end of file From ca5a8172334c46234e0a5453952a2e2873a1a508 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Thu, 6 Jun 2019 15:14:11 -0700 Subject: [PATCH 088/176] Minor improve to assertion (#3295) --- nnvm/src/core/graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 29149f48fdb0..e2d6d36020f1 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -107,7 +107,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { for (const auto& nptr : n->control_deps) { if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; auto it = node2index_.find(nptr.get()); - CHECK(it != node2index_.end() && it->first == nptr.get()); + CHECK(it != node2index_.end()) << "control dep not found in graph"; control_deps_.push_back(it->second); } control_rptr.push_back(control_deps_.size()); From 1c8b5c3a1a042d357190b37afe70f6f443ea4954 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Thu, 6 Jun 2019 23:52:07 -0700 Subject: [PATCH 089/176] [VTA] add doc to tsim-example driver and update verilator env variable (#3302) * add documentation and check for extension * add env variable for verilator include * fix typo * this will test if path exist otherwise it won't buid * check if verilator path and binary is set properly * add ? * remove export * no longer needed --- vta/apps/tsim_example/python/tsim/driver.py | 21 ++++++++++++++++---- vta/hardware/chisel/Makefile | 22 +++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/vta/apps/tsim_example/python/tsim/driver.py b/vta/apps/tsim_example/python/tsim/driver.py index 997d9d527bfe..c388b99cbec9 100644 --- a/vta/apps/tsim_example/python/tsim/driver.py +++ b/vta/apps/tsim_example/python/tsim/driver.py @@ -21,14 +21,27 @@ import os.path as osp from sys import platform -def driver(hw, sw): +def driver(hw_lib, sw_lib): + """Init hardware and software shared library for add-by-one accelerator + + Parameters + ------------ + hw_lib : str + Name of hardware shared library + + sw_lib : str + Name of software shared library + """ _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) _root_path = osp.join(_cur_path, "..", "..") _cfg_file = osp.join(_root_path, "config", "config.json") _cfg = json.load(open(_cfg_file)) - _ext = ".dylib" if platform == "darwin" else ".so" - _hw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], hw + _ext) - _sw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], sw + _ext) + if not hw_lib.endswith(("dylib", "so")): + hw_lib += ".dylib" if platform == "darwin" else ".so" + if not sw_lib.endswith(("dylib", "so")): + sw_lib += ".dylib" if platform == "darwin" else ".so" + _hw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], hw_lib) + _sw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], sw_lib) def load_dll(dll): try: diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 7371dd1b3686..91e40a022337 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. +# Change this variable if Verilator is installed on a different location +VERILATOR_INC_DIR ?= /usr/local/share/verilator/include + +ifeq (, $(shell which verilator)) + $(error "No Verilator in $(PATH), consider doing apt-get install verilator") +endif + +ifeq (, $(wildcard $(VERILATOR_INC_DIR)/*)) + $(error "Verilator include directory is not set properly") +endif + CONFIG = DefaultF1Config TOP = VTA TOP_TEST = Test @@ -25,7 +36,6 @@ VTA_LIBNAME = libvta_hw config_test = $(TOP_TEST)$(CONFIG) vta_dir = $(abspath ../../) tvm_dir = $(abspath ../../../) -verilator_inc_dir = /usr/local/share/verilator/include verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel @@ -50,14 +60,14 @@ cxx_flags += -DVM_SC=0 cxx_flags += -Wno-sign-compare cxx_flags += -include V$(TOP_TEST).h cxx_flags += -I$(verilator_build_dir) -cxx_flags += -I$(verilator_inc_dir) -cxx_flags += -I$(verilator_inc_dir)/vltstd +cxx_flags += -I$(VERILATOR_INC_DIR) +cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd cxx_flags += -I$(vta_dir)/include cxx_flags += -I$(tvm_dir)/include cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include -cxx_files = $(verilator_inc_dir)/verilated.cpp -cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp +cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp +cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp cxx_files += $(wildcard $(verilator_build_dir)/*.cpp) cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc @@ -65,7 +75,7 @@ ifneq ($(USE_TRACE), 0) verilator_opt += --trace cxx_flags += -DVM_TRACE=1 cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd - cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp + cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp else cxx_flags += -DVM_TRACE=0 endif From 7063b77dca38d12488316b979a90b616a431c34d Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 7 Jun 2019 00:01:01 -0700 Subject: [PATCH 090/176] Fix some typos in api docs (#3309) --- include/tvm/relay/error.h | 2 +- include/tvm/runtime/c_runtime_api.h | 2 +- nnvm/include/nnvm/c_api.h | 2 +- nnvm/python/nnvm/frontend/common.py | 6 +++--- nnvm/python/nnvm/frontend/tensorflow.py | 2 +- python/tvm/relay/frontend/common.py | 6 +++--- python/tvm/relay/frontend/tensorflow.py | 12 ++++++------ python/tvm/relay/op/nn/nn.py | 14 +++++++------- src/common/socket.h | 2 +- src/pass/arg_binder.h | 2 +- topi/python/topi/cuda/reduction.py | 2 +- 11 files changed, 26 insertions(+), 26 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 6b9a1fa7b7c6..5189fd982d37 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -83,7 +83,7 @@ struct Error : public dmlc::Error { * * The final mode represents the old mode, if we report an error that has no span or * expression, we will default to throwing an exception with a textual representation - * of the error and no indication of where it occured in the original program. + * of the error and no indication of where it occurred in the original program. * * The latter mode is not ideal, and the goal of the new error reporting machinery is * to avoid ever reporting errors in this style. diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index fd1b877f6d4c..ba2c0d2291b6 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -187,7 +187,7 @@ TVM_DLL void TVMAPISetLastError(const char* msg); /*! * \brief return str message of the last error * all function in this file will return 0 when success - * and -1 when an error occured, + * and -1 when an error occurred, * TVMGetLastError can be called to retrieve the error * * this function is threadsafe and can be called by different thread diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index 75054e892d8e..773bc63b7dad 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -60,7 +60,7 @@ NNVM_DLL void NNAPISetLastError(const char* msg); /*! * \brief return str message of the last error * all function in this file will return 0 when success - * and -1 when an error occured, + * and -1 when an error occurred, * NNGetLastError can be called to retrieve the error * * this function is threadsafe and can be called by different thread diff --git a/nnvm/python/nnvm/frontend/common.py b/nnvm/python/nnvm/frontend/common.py index 610546d1973b..0e09a2c43323 100644 --- a/nnvm/python/nnvm/frontend/common.py +++ b/nnvm/python/nnvm/frontend/common.py @@ -58,7 +58,7 @@ def __call__(self, inputs, attrs, *args): class AttrConverter(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -72,12 +72,12 @@ class AttrConverter(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in nnvm. Log warnings. ignores : list diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 2f91cad8143a..244b48eb3d5a 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1197,7 +1197,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): -> All Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Graph should be frozen with add_shapes=True. - Or user can pass input shape dictionaly optionally. + Or user can pass input shape dictionary optionally. -> DecodeJpeg, ResizeBilinear: These are dummy operators. Hence user should handle preprocessing outside. -> CheckNumerics: No implementation as of now for this. diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 23477626b63b..efd198803c2b 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -286,7 +286,7 @@ def clear_padding(self): class AttrCvt(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -300,12 +300,12 @@ class AttrCvt(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in relay. Log warnings. ignores : list diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f709a63e79e8..45ae2cd19cd1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -77,12 +77,12 @@ class AttrCvt(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in relay. Log warnings. ignores : list @@ -1567,7 +1567,7 @@ def _in_while_loop(control_flow_node_map, op_name): Parameters ---------- control_flow_node_map : Dict[str, Set[str]] - A dictionay contains the unqiue control flow execution frame name to + A dictionay contains the unique control flow execution frame name to a set of primitive operators mapping. op_name : str @@ -1619,7 +1619,7 @@ def f2(): return tf.add(4, 23) r = tf.cond(tf.less(i, j), f1, f2) - This condition statement should be coverted into Relay in the following + This condition statement should be converted into Relay in the following form: .. code-block:: python @@ -1727,7 +1727,7 @@ def __init__(self): self._loop = None def _while_loop(self): - """An internal API to create a Relay recurisve call for a matched TF + """An internal API to create a Relay recursive call for a matched TF `while_loop` construct. """ wl = tvm.relay.var('while_loop') @@ -1796,7 +1796,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): -> All Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Graph should be frozen with add_shapes=True. - Or user can pass input shape dictionaly optionally. + Or user can pass input shape dictionary optionally. -> DecodeJpeg, ResizeBilinear: These are dummy operators. Hence user should handle preprocessing outside. -> CheckNumerics: No implementation as of now for this. diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b4ebffb355d0..7bce9dd3c5b9 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -67,7 +67,7 @@ def conv2d(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -129,7 +129,7 @@ def conv2d_transpose(data, The weight expressions. strides : Tuple[int], optional - The strides of convoltution. + The strides of convolution. padding : Tuple[int], optional The padding of convolution on both sides of inputs. @@ -842,7 +842,7 @@ def contrib_conv2d_winograd_without_weight_transform(data, The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -908,7 +908,7 @@ def contrib_conv2d_winograd_nnpack_without_weight_transform(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -975,7 +975,7 @@ def contrib_conv2d_nchwc(data, The kernel expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -1040,7 +1040,7 @@ def contrib_depthwise_conv2d_nchwc(data, The kernel expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -1156,7 +1156,7 @@ def deformable_conv2d(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. diff --git a/src/common/socket.h b/src/common/socket.h index 58705f16bf73..91f9f4e5cf0a 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -373,7 +373,7 @@ class TCPSocket : public Socket { } /*! * \brief decide whether the socket is at OOB mark - * \return 1 if at mark, 0 if not, -1 if an error occured + * \return 1 if at mark, 0 if not, -1 if an error occurred */ int AtMark() const { #ifdef _WIN32 diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h index 9de3a13270dc..f235ea49faac 100644 --- a/src/pass/arg_binder.h +++ b/src/pass/arg_binder.h @@ -50,7 +50,7 @@ namespace ir { * - assert bufferB.shape[1] == n + 3 * * In general, this is a constraint solving problem. We have simplified assumption - * over the binding declaration, such that we require the variable occured in + * over the binding declaration, such that we require the variable occurred in * constraint must be declared in argument list. So it is illegal to have signature * f(tA(shape=(n+3))) without any argument variable corresponds to n, even though * it is already enough to derive n from the input argument. diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py index ff7232cc0fac..25885315179c 100644 --- a/topi/python/topi/cuda/reduction.py +++ b/topi/python/topi/cuda/reduction.py @@ -37,7 +37,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): num_thread = 32 target = tvm.target.current_target() if target and target.target_name == "opencl": - # without it, CL_INVALID_WORK_GROUP_SIZE occured when running test_topi_reduce.py + # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why num_thread = 16 block_x = tvm.thread_axis("blockIdx.x") From a2864a807fff505b64135b42f688193a27d2c323 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Fri, 7 Jun 2019 17:07:03 +0100 Subject: [PATCH 091/176] [DOC] Capitalize TVM consistently (#3316) --- tutorials/autotvm/tune_simple_template.py | 10 +++++----- tutorials/language/schedule_primitives.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tutorials/autotvm/tune_simple_template.py b/tutorials/autotvm/tune_simple_template.py index c7eea7f42c0b..0a7b9f2dd816 100644 --- a/tutorials/autotvm/tune_simple_template.py +++ b/tutorials/autotvm/tune_simple_template.py @@ -19,19 +19,19 @@ ============================================= **Author**: `Lianmin Zheng `_ -This is an introduction tutorial to the auto-tuning module in tvm. +This is an introduction tutorial to the auto-tuning module in TVM. There are two steps in auto-tuning. The first step is defining a search space. The second step is running a search algorithm to explore through this space. -In this tutorial, you can learn how to perform these two steps in tvm. +In this tutorial, you can learn how to perform these two steps in TVM. The whole workflow is illustrated by a matrix multiplication example. """ ###################################################################### # Install dependencies # -------------------- -# To use autotvm package in tvm, we need to install some extra dependencies. +# To use autotvm package in TVM, we need to install some extra dependencies. # (change "3" to "2" if you use python2): # # .. code-block:: bash @@ -65,7 +65,7 @@ # tunable schedule template. You can regard the process of search space definition # as the parameterization of our existing schedule code. # -# To begin with, here is how we implement a blocked matrix multiplication in tvm. +# To begin with, here is how we implement a blocked matrix multiplication in TVM. # Matmul V0: Constant tiling factor def matmul_v0(N, L, M, dtype): @@ -236,7 +236,7 @@ def matmul(N, L, M, dtype): # In step 1, we build the search space by extending our old schedule code # into a template. The next step is to pick a tuner and explore in this space. # -# Auto-tuners in tvm +# Auto-tuners in TVM # ^^^^^^^^^^^^^^^^^^ # The job for a tuner can be described by following pseudo code # diff --git a/tutorials/language/schedule_primitives.py b/tutorials/language/schedule_primitives.py index 44283818edf4..47ef7173c632 100644 --- a/tutorials/language/schedule_primitives.py +++ b/tutorials/language/schedule_primitives.py @@ -144,7 +144,7 @@ ###################################################################### # compute_at # ---------- -# For a schedule consists of multiple operators, tvm will compute +# For a schedule consists of multiple operators, TVM will compute # tensors at the root separately by default. A = tvm.placeholder((m,), name='A') B = tvm.compute((m,), lambda i: A[i]+1, name='B') From 8a08651829493948ed9ff6577da0f8fc0bf6fb35 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Fri, 7 Jun 2019 17:07:36 +0100 Subject: [PATCH 092/176] [LINT] Improve robustness in task_lint.sh logic (#3315) The existing RAT ASF license auditing logic ignores any failure in the shell pipeline rather than just the exit code of the final grep. Adjust the logic such that failure of the various tools in the pipeline are not elided away. --- tests/scripts/task_lint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index e4b20a2f4b40..5116a42afb93 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -31,7 +31,7 @@ echo "Check file types..." python3 tests/lint/check_file_type.py echo "Check ASF license header..." -java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . |grep "== File" > /tmp/$$.apache-rat.txt || true +java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true) if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then echo "Need to add ASF header to the following files." echo "----------------File List----------------" From a126e3b66e4575e7da892966440cae5a7dfcfe7b Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Fri, 7 Jun 2019 20:51:14 +0100 Subject: [PATCH 093/176] [DOC] minor language use improvements (#3317) --- docs/dev/codebase_walkthrough.rst | 14 +++++++------- tutorials/language/schedule_primitives.py | 2 +- tutorials/tensor_expr_get_started.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 788f1f8b50a3..6aa175c3f114 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -43,7 +43,7 @@ When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compile - Generate a compute expression and a schedule for the operator - Compile the operator into object code -One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. +One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that does heavy lifting is implemented in C++, and Python bindings are provided for the user interface. This is also true in TVM, but in TVM codebase, C++ code can also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. ******************************************* Vector Add Example @@ -84,7 +84,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ args[4]); }); -We use ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of `PackedFunc `_. ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy. +We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc `_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy. A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``. @@ -141,7 +141,7 @@ Bound inference is the process where all loop bounds and sizes of intermediate b .. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html -``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects that changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. +``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below. @@ -173,7 +173,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ } -``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: +The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: :: @@ -182,9 +182,9 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ *rv = BuildCUDA(args[0]); }); -``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. +The ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. -``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. +The ``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages the CUDA driver API. The ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects. @@ -243,4 +243,4 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal } }; -This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, at the end all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase. +This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, in the end, all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase. diff --git a/tutorials/language/schedule_primitives.py b/tutorials/language/schedule_primitives.py index 47ef7173c632..e59264f29898 100644 --- a/tutorials/language/schedule_primitives.py +++ b/tutorials/language/schedule_primitives.py @@ -144,7 +144,7 @@ ###################################################################### # compute_at # ---------- -# For a schedule consists of multiple operators, TVM will compute +# For a schedule that consists of multiple operators, TVM will compute # tensors at the root separately by default. A = tvm.placeholder((m,), name='A') B = tvm.compute((m,), lambda i: A[i]+1, name='B') diff --git a/tutorials/tensor_expr_get_started.py b/tutorials/tensor_expr_get_started.py index cdd07d466a37..a0b84f0e81ca 100644 --- a/tutorials/tensor_expr_get_started.py +++ b/tutorials/tensor_expr_get_started.py @@ -108,7 +108,7 @@ ###################################################################### # Finally we bind the iteration axis bx and tx to threads in the GPU -# compute grid. These are GPU specific constructs that allows us +# compute grid. These are GPU specific constructs that allow us # to generate code that runs on GPU. # if tgt == "cuda" or tgt.startswith('opencl'): From c68d5c57ff366206d79f73cc774d99beb5a099d7 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Fri, 7 Jun 2019 20:52:00 +0100 Subject: [PATCH 094/176] [CI] Ensure rat ignores rust cargo lock files [CI] Ensure rat ignores emacs backup files [CI] Ensure rat ignores .egg-info (#3314) --- tests/lint/rat-excludes | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index f449c5ee68b9..72faa1112c94 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -25,6 +25,7 @@ .*\.log # Generated modules +.*\.egg-info .*gen_modules/* .*doxygen core.cpp @@ -32,6 +33,7 @@ build _static _build .*~ +\#..*\# # Specific files package-list @@ -44,3 +46,4 @@ rat-excludes __init__.py pylintrc config.cmake +Cargo.lock From a6083b05223ff0f11593930b253ccaf943337c2f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 7 Jun 2019 14:38:57 -0700 Subject: [PATCH 095/176] [PASS][RELAY] polish pass infra (#3319) --- 3rdparty/dmlc-core | 2 +- src/relay/pass/pass_manager.cc | 168 +++++++++------------------------ 2 files changed, 45 insertions(+), 125 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 3943914eed66..fbe142b267a8 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f +Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661 diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 13e908d28f7a..05eb43d6a653 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,47 +37,6 @@ namespace transform { using tvm::IRPrinter; -namespace { - -// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be -// handled because we need to register the pass for Python invocation anyway. -Pass GetPass(const std::string& pass_name) { - if (pass_name == "InferType") { - return InferType(); - } else if (pass_name == "AlterOpLayout") { - return AlterOpLayout(); - } else if (pass_name == "CanonicalizeOps") { - return CanonicalizeOps(); - } else if (pass_name == "CombineParallelConv2d") { - return CombineParallelConv2D(); - } else if (pass_name == "DeadCodeElimination") { - return DeadCodeElimination(); - } else if (pass_name == "EliminateCommonSubexpr") { - return DeadCodeElimination(); - } else if (pass_name == "FoldConstant") { - return FoldConstant(); - } else if (pass_name == "BackwardFoldScaleAxis") { - return FoldScaleAxis(); - } else if (pass_name == "ForwardFoldScaleAxis") { - return FoldScaleAxis(); - } else if (pass_name == "FoldScaleAxis") { - return FoldScaleAxis(); - } else if (pass_name == "PartialEvaluate") { - return SimplifyInference(); - } else if (pass_name == "SimplifyInference") { - return SimplifyInference(); - } else if (pass_name == "ToANormalForm") { - return ToANormalForm(); - } else if (pass_name == "ToGraphNormalForm") { - return ToGraphNormalForm(); - } else { - LOG(FATAL) << pass_name << " has not been registered yet." << "\n"; - return Pass(nullptr); - } -} - -} // namespace - struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ PassContext default_context; @@ -252,6 +211,7 @@ class SequentialNode : public PassNode { /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); @@ -262,23 +222,14 @@ class SequentialNode : public PassNode { */ PassInfo Info() const { return pass_info; } - /*! - * \brief Add a pass to the pass list. - * - * \param pass The candidate pass to be added. - */ - void AddPass(const Pass& pass) { - passes.push_back(pass); - } - /*! * \brief Check if a pass is enabled. * - * \param pass_name The name of an optimization/analysis pass. + * \param info The pass information. * * \return true if the pass is enabled. Otherwise, false. */ - bool PassEnabled(const std::string& pass_name) const; + bool PassEnabled(const PassInfo& info) const; /*! * \brief Resolve the pass dependency. It globs all required passes by @@ -294,12 +245,6 @@ class SequentialNode : public PassNode { */ void ResolveDependency(const Module& mod); - std::unordered_set DisabledPasses( - const Array& disabled) const; - - std::unordered_set RequiredPasses( - const Array& required) const; - /*! * \brief Perform optimizations on a series of passes. The aforementioned * typical pass manager jobs could be done by it. This function could @@ -317,7 +262,8 @@ class SequentialNode : public PassNode { TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); }; -PassInfo PassInfoNode::make(int opt_level, std::string name, +PassInfo PassInfoNode::make(int opt_level, + std::string name, tvm::Array required) { auto pass_info = make_node(); pass_info->opt_level = opt_level; @@ -338,23 +284,13 @@ ModulePass ModulePassNode::make( // Module -> Module optimizations. Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { - PassInfo pass_info = Info(); - DLOG(INFO) << "Executing module pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level << "\n"; - + const PassInfo& pass_info = Info(); + DLOG(INFO) << "Executing module pass : " + << pass_info->name + << " with opt level: " + << pass_info->opt_level; CHECK(mod.defined()); - Module updated_mod = mod; - // Execute the required passes in a DFS way. - // TODO(zhiics) We may need to pass validation to detect the cyclic - // dependency. - for (const auto& it : pass_info->required) { - const auto* name = it.as(); - CHECK(name); - auto pass = GetPass(name->value); - updated_mod = pass(updated_mod, pass_ctx); - } - - updated_mod = pass_func(updated_mod, pass_ctx); + Module updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } @@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make( } // Perform Module -> Module optimizations at the Function level. -// TODO(zhiics) Check and handle the required passes. Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { - PassInfo pass_info = Info(); + const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing module pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level << "\n"; - + DLOG(INFO) << "Executing module pass : " + << pass_info->name + << " with opt level: " + << pass_info->opt_level; Module updated_mod = mod; - // Execute the required passes in a DFS way. - // TODO(zhiics) We may need to pass validation to detect the cyclic - // dependency. - for (const auto& it : pass_info->required) { - const auto* name = it.as(); - CHECK(name); - auto pass = GetPass(name->value); - updated_mod = pass(updated_mod, pass_ctx); - } - Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. for (const auto& it : mod->functions) { @@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod, : pass_func(it.second, updated_mod, pass_ctx); new_mod->Add(it.first, updated_func); } - return new_mod; } @@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) { << "\n"; } -std::unordered_set SequentialNode::DisabledPasses( - const Array& disabled) const { - std::unordered_set ret; - for (const auto& it : disabled) { - const auto* str = it.as(); - CHECK(str) << "Disabled pass name must be string."; - ret.emplace(str->value); - } - return ret; -} - -std::unordered_set SequentialNode::RequiredPasses( - const Array& required) const { - std::unordered_set ret; - for (const auto& it : required) { - const auto* str = it.as(); - CHECK(str) << "Required pass name must be string."; - ret.emplace(str->value); +// linearly scan the pass array to match pass_name +inline bool PassArrayContains(const Array& pass_array, + const std::string& pass_name) { + for (auto x : pass_array) { + auto* str_name = x.as(); + CHECK(str_name) << "pass name must be str"; + if (str_name->value == pass_name) return true; } - return ret; + return false; } -bool SequentialNode::PassEnabled(const std::string& pass_name) const { +bool SequentialNode::PassEnabled(const PassInfo& info) const { PassContext ctx = PassContext::Current(); - auto required = RequiredPasses(ctx->required_pass); - auto disabled = DisabledPasses(ctx->disabled_pass); - - if (disabled.count(pass_name)) { + if (PassArrayContains(ctx->disabled_pass, info->name)) { return false; } - if (required.count(pass_name)) { + if (PassArrayContains(ctx->required_pass, info->name)) { return true; } - const Pass pass = GetPass(pass_name); - PassInfo info = pass->Info(); return ctx->opt_level >= info->opt_level; } +Pass GetPass(const std::string& pass_name) { + using tvm::runtime::Registry; + std::string fpass_name = "relay._transform." + pass_name; + const auto* f = Registry::Get(fpass_name); + CHECK(f != nullptr) << "Cannot find " << fpass_name + << "to create the pass " << pass_name; + return (*f)(); +} + // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. @@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module, Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; - - PassInfo info = pass->Info(); - const auto& pass_name = info->name; - // Execute the pass if it is enabled. - if (PassEnabled(pass_name)) { - mod = pass(mod, pass_ctx); + const PassInfo& pass_info = pass->Info(); + if (!PassEnabled(pass_info)) continue; + // resolve dependencies + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + mod = GetPass(name->value)(mod, pass_ctx); } + mod = pass(mod, pass_ctx); } return mod; } From c34dddd3bc5c8751c9889f0665630faf46fa5454 Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Sat, 8 Jun 2019 12:17:29 -0400 Subject: [PATCH 096/176] Make the behavior of data nullptr check of pooling layer same as others. (#3322) --- 3rdparty/dmlc-core | 2 +- src/relay/op/nn/pooling.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index fbe142b267a8..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 4dd763b45654..44c9f89aa9e7 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -70,7 +70,8 @@ bool Pool2DRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) return false; + const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; From ff4ffe9c679a8bb1ada2c608320e84ffed3ce4a5 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Sat, 8 Jun 2019 15:00:31 -0700 Subject: [PATCH 097/176] [VTA] [APPS] [TSIM] update documentation (README) (#3318) * update README * update README * update README * update README * fix typo --- vta/apps/tsim_example/README.md | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index e19d6bbe7ef0..8f1230e9ba7e 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -49,22 +49,29 @@ sudo apt install verilator sbt ## Setup in TVM 1. Install `verilator` and `sbt` as described above -2. Build tvm +2. Change `TARGET` to `tsim` in `/tvm/vta/config/vta_config.json` +3. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library) ## How to run VTA TSIM examples There are two sample VTA accelerators (add-by-one) designed in Chisel3 and Verilog to show how *TSIM* works. -These examples are located at `/vta/apps/tsim_example`. +The default `TARGET` language for these two implementations is Verilog. The following instructions show +how to run both of them: -* Instructions +* Verilog add-by-one + * Go to `/vta/apps/tsim_example` + * Run `make` to build and run add-by-one test + +* Chisel3 add-by-one * Open `/vta/apps/tsim_example/python/tsim/config.json` - * Change `TARGET` from `verilog` to `chisel`, depending on what language backend you would like to test + * Change `TARGET` from `verilog` to `chisel` * Go to `tvm/vta/apps/tsim_example` - * Run `make` + * Run `make` to build and run add-by-one test * Some pointers - * Build cmake script for software library`/vta/apps/tsim_example/cmake/modules/sw.cmake` - * Build cmake script for hardware library`/vta/apps/tsim_example/cmake/modules/hw.cmake` - * Software driver that handles the accelerator `/vta/apps/tsim_example/src/driver.cc` + * Add-by-one test `/vta/apps/tsim_example/tests/python/add_by_one.py` * Add-by-one accelerator in Verilog `/vta/apps/tsim_example/hardware/verilog` * Add-by-one accelerator in Chisel3 `/vta/apps/tsim_example/hardware/chisel` + * Software driver that handles the accelerator `/vta/apps/tsim_example/src/driver.cc` + * Build cmake script for software library`/vta/apps/tsim_example/cmake/modules/sw.cmake` + * Build cmake script for hardware library`/vta/apps/tsim_example/cmake/modules/hw.cmake` From dc8d27e27c86ceba8918a62f741460bfb9fc0704 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sat, 8 Jun 2019 20:56:58 -0700 Subject: [PATCH 098/176] [Rust] Static syslib (#3274) --- rust/Cargo.toml | 1 + rust/macros/Cargo.toml | 36 ++++++ rust/macros/src/lib.rs | 122 ++++++++++++++++++ rust/runtime/Cargo.toml | 3 +- rust/runtime/src/graph.rs | 6 +- rust/runtime/src/lib.rs | 2 +- rust/runtime/tests/test_tvm_basic/build.rs | 20 ++- rust/runtime/tests/test_tvm_basic/src/main.rs | 19 ++- 8 files changed, 193 insertions(+), 16 deletions(-) create mode 100644 rust/macros/Cargo.toml create mode 100644 rust/macros/src/lib.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6e89bae5c6f2..02e2c7c67c99 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -18,6 +18,7 @@ [workspace] members = [ "common", + "macros", "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml new file mode 100644 index 000000000000..15773b625be9 --- /dev/null +++ b/rust/macros/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-macros" +version = "0.1.0" +license = "Apache-2.0" +description = "Proc macros used by the TVM crates." +repository = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["tvm"] +authors = ["TVM Contributors"] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +goblin = "0.0.22" +proc-macro2 = "0.4" +proc-quote = "0.2" +syn = "0.15" diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs new file mode 100644 index 000000000000..704f7c1de58b --- /dev/null +++ b/rust/macros/src/lib.rs @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![feature(bind_by_move_pattern_guards, proc_macro_span)] + +extern crate proc_macro; + +use std::{fs::File, io::Read}; + +use proc_quote::quote; + +#[proc_macro] +pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let obj_file_path = syn::parse_macro_input!(input as syn::LitStr); + + let mut path = obj_file_path.span().unwrap().source_file().path(); + path.pop(); // remove the filename + path.push(obj_file_path.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] + } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } + }; + + let fns = quote! { + use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[TVMArgValue]) -> Result { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) +} diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index 5809af0c6c6d..3c81a93c9bbf 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -41,7 +41,8 @@ nom = {version = "4.0.0", default-features = false } serde = "1.0.59" serde_derive = "1.0.79" serde_json = "1.0.17" -tvm-common = { version = "0.1.0", path = "../common/" } +tvm-common = { version = "0.1", path = "../common" } +tvm-macros = { version = "0.1", path = "../macros" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index bff02f504a5e..cacd7a38a97f 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -164,7 +164,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// ``` pub struct GraphExecutor<'m, 't> { graph: Graph, - op_execs: Vec>, + op_execs: Vec>, tensors: Vec>, } @@ -240,7 +240,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { graph: &Graph, lib: &'m M, tensors: &Vec>, - ) -> Result>, Error> { + ) -> Result>, Error> { ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); @@ -279,7 +279,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { }) .collect::, Error>>() .unwrap(); - let op: Box = box move || { + let op: Box = box move || { let args = dl_tensors .iter() .map(|t| t.into()) diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index c774d5bbc983..010fbf7d6a29 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -29,7 +29,6 @@ //! For examples of use, please refer to the multi-file tests in the `tests` directory. #![feature( - alloc, allocator_api, box_syntax, fn_traits, @@ -77,6 +76,7 @@ pub use tvm_common::{ packed_func::{self, *}, TVMArgValue, TVMRetValue, }; +pub use tvm_macros::import_module; pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs index ea3bfcb85136..3439f9c2efc7 100644 --- a/rust/runtime/tests/test_tvm_basic/build.rs +++ b/rust/runtime/tests/test_tvm_basic/build.rs @@ -19,13 +19,21 @@ extern crate ar; -use std::{env, path::Path, process::Command}; +use std::{path::PathBuf, process::Command}; use ar::Builder; use std::fs::File; fn main() { - let out_dir = env::var("OUT_DIR").unwrap(); + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest.a"); let output = Command::new(concat!( env!("CARGO_MANIFEST_DIR"), @@ -35,7 +43,7 @@ fn main() { .output() .expect("Failed to execute command"); assert!( - Path::new(&format!("{}/test.o", out_dir)).exists(), + obj_file.exists(), "Could not build tvm lib: {}", String::from_utf8(output.stderr) .unwrap() @@ -45,9 +53,9 @@ fn main() { .unwrap_or("") ); - let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap()); - builder.append_path(format!("{}/test.o", out_dir)).unwrap(); + let mut builder = Builder::new(File::create(lib_file).unwrap()); + builder.append_path(obj_file).unwrap(); println!("cargo:rustc-link-lib=static=test"); - println!("cargo:rustc-link-search=native={}", out_dir); + println!("cargo:rustc-link-search=native={}", out_dir.display()); } diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs index 14bb7c20c680..a83078e5834a 100644 --- a/rust/runtime/tests/test_tvm_basic/src/main.rs +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -22,13 +22,14 @@ extern crate ndarray; extern crate tvm_runtime; use ndarray::Array; -use tvm_runtime::{DLTensor, Module, SystemLibModule}; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +mod tvm_mod { + import_module!("../lib/test.o"); +} fn main() { - let syslib = SystemLibModule::default(); - let add = syslib - .get_function("default_function") - .expect("main function not found"); + // try static let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); let mut c = Array::from_vec(vec![0f32; 4]); @@ -36,6 +37,14 @@ fn main() { let mut a_dl: DLTensor = (&mut a).into(); let mut b_dl: DLTensor = (&mut b).into(); let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); + + // try runtime + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); assert!(c.all_close(&e, 1e-8f32)); } From bfdda281dcbf73881563321b9b4cdaf47acad2f3 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Sun, 9 Jun 2019 13:34:56 -0700 Subject: [PATCH 099/176] Improve non_max_suppression and get_valid_counts for CPU (#3305) * Improve non_max_suppression for CPU * Improve get_valid_counts * Minor change * Skip some unnecessary computes --- include/tvm/relay/attrs/vision.h | 6 ++ python/tvm/relay/frontend/mxnet.py | 3 +- python/tvm/relay/op/vision/_vision.py | 5 +- python/tvm/relay/op/vision/nms.py | 13 +++- src/relay/op/vision/nms.cc | 6 +- tests/python/relay/test_op_level5.py | 20 +++---- topi/python/topi/cuda/nms.py | 8 ++- topi/python/topi/vision/nms.py | 85 +++++++++++++++++---------- topi/tests/python/test_topi_vision.py | 83 ++++++++++++++++---------- 9 files changed, 152 insertions(+), 77 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 11b4ebfcfaad..7fa1ffb8a4fe 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -79,10 +79,16 @@ struct MultiBoxTransformLocAttrs /*! \brief Attributes used in get_valid_counts operator */ struct GetValidCountsAttrs : public tvm::AttrsNode { double score_threshold; + int id_index; + int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { TVM_ATTR_FIELD(score_threshold).set_default(0.0) .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0) + .describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1) + .describe("Index of the scores/confidence of boxes."); } }; diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0975a33450c8..81ef51b91336 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -569,7 +569,8 @@ def _mx_box_nms(inputs, attrs): raise tvm.error.OpAttributeInvalid( 'Value of attribute "out_format" must equal "corner" for operator box_nms.') - ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) + ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh, + id_index=id_index, score_index=score_index) nms_out = _op.vision.non_max_suppression(ret[1], ret[0], iou_threshold=iou_thresh, diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8c8c4cd9aaa3..7de118071aa4 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -82,7 +82,10 @@ def schedule_get_valid_counts(_, outs, target): def compute_get_valid_counts(attrs, inputs, _, target): """Compute definition of get_valid_counts""" score_threshold = get_const_float(attrs.score_threshold) - return topi.vision.get_valid_counts(inputs[0], score_threshold) + id_index = get_const_int(attrs.id_index) + score_index = get_const_int(attrs.score_index) + return topi.vision.get_valid_counts(inputs[0], score_threshold, + id_index, score_index) reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index ab34eb6e6cfb..d19dde306aca 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -20,7 +20,9 @@ from ...expr import TupleWrapper def get_valid_counts(data, - score_threshold): + score_threshold, + id_index=0, + score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -32,6 +34,12 @@ def get_valid_counts(data, score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- valid_count : relay.Expr @@ -40,7 +48,8 @@ def get_valid_counts(data, out_tensor : relay.Expr Rearranged data tensor. """ - return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2) + return TupleWrapper(_make.get_valid_counts(data, score_threshold, + id_index, score_index), 2) def non_max_suppression(data, diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 2e5661cdc4dc..c0160e7d7128 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -50,9 +50,13 @@ bool GetValidCountRel(const Array& types, } Expr MakeGetValidCounts(Expr data, - double score_threshold) { + double score_threshold, + int id_index, + int score_index) { auto attrs = make_node(); attrs->score_threshold = score_threshold; + attrs->id_index = id_index; + attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); return CallNode::make(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 21b227f6b3b5..3d9ec6dde4ad 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -152,28 +152,28 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), def test_get_valid_counts(): - def verify_get_valid_counts(dshape, score_threshold): + def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" batch_size, num_anchor, elem_length = dshape - np_data = np.random.uniform(size=dshape).astype(dtype) + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 for j in range(num_anchor): - score = np_data[i, j, 1] - if score >= score_threshold: + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): - np_out2[i, j, k] = -1 + np_out2[i, j, k] = -1.0 x = relay.var("x", relay.ty.TensorType(dshape, dtype)) - z = relay.vision.get_valid_counts(x, score_threshold) + z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index) assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) func = relay.ir_pass.infer_type(func) @@ -185,10 +185,10 @@ def verify_get_valid_counts(dshape, score_threshold): tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) - verify_get_valid_counts((1, 2500, 6), 0) - verify_get_valid_counts((1, 2500, 6), -1) - verify_get_valid_counts((3, 1000, 6), 0.55) - verify_get_valid_counts((16, 500, 6), 0.95) + verify_get_valid_counts((1, 2500, 6), 0, 0, 1) + verify_get_valid_counts((1, 2500, 5), -1, -1, 0) + verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) + verify_get_valid_counts((16, 500, 5), 0.95, -1, 0) def test_non_max_suppression(): diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 460584bc8b78..c0da4a45ec8d 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -313,7 +313,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): @get_valid_counts.register(["cuda", "gpu"]) -def get_valid_counts_gpu(data, score_threshold=0): +def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -325,6 +325,12 @@ def get_valid_counts_gpu(data, score_threshold=0): score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- valid_count : tvm.Tensor diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 7c8d7db33059..a6ba56eeb943 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements +# pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args """Non-maximum suppression operator""" import tvm @@ -60,7 +60,7 @@ def hybrid_rearrange_out(data): @hybrid.script -def hybrid_get_valid_counts(data, score_threshold): +def hybrid_get_valid_counts(data, score_threshold, id_index, score_index): """Hybrid routine to get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -68,11 +68,18 @@ def hybrid_get_valid_counts(data, score_threshold): Parameters ---------- data : tvm.Tensor or numpy NDArray - Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + Input data. 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. score_threshold : tvm.const Lower limit of score for valid bounding boxes. + id_index : tvm.const + index of the class categories, -1 to disable. + + score_index: tvm.const + Index of the scores/confidence of boxes. + Returns ------- out_tensor : tvm.Tensor or numpy NDArray @@ -92,8 +99,9 @@ def hybrid_get_valid_counts(data, score_threshold): for i in parallel(batch_size): valid_count[i] = 0 for j in range(num_anchors): - score = data[i, j, 1] - if score > score_threshold: + score = data[i, j, score_index] + if score > score_threshold and \ + (id_index < 0 or data[i, j, id_index] >= 0): for k in range(box_data_length): out_tensor[i, valid_count[i], k] = data[i, j, k] valid_count[i] += 1 @@ -103,18 +111,25 @@ def hybrid_get_valid_counts(data, score_threshold): return valid_count, out_tensor @tvm.target.generic_func -def get_valid_counts(data, score_threshold=0): +def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. Parameters ---------- data : tvm.Tensor - Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + Input data. 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- out_tensor : tvm.Tensor @@ -123,14 +138,17 @@ def get_valid_counts(data, score_threshold=0): valid_count : tvm.Tensor 1-D tensor for valid number of boxes. """ - score_threshold_const = tvm.const(score_threshold, "float") - return hybrid_get_valid_counts(data, score_threshold_const) + score_threshold_const = tvm.const(score_threshold, "float32") + id_index_const = tvm.const(id_index, "int32") + score_index_const = tvm.const(score_index, "int32") + return hybrid_get_valid_counts(data, score_threshold_const, + id_index_const, score_index_const) @hybrid.script def hybrid_nms(data, sorted_index, valid_count, max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index): + top_k, coord_start, id_index, score_index): """Hybrid routing for non-maximum suppression. Parameters @@ -165,6 +183,9 @@ def hybrid_nms(data, sorted_index, valid_count, id_index : tvm.const index of the class categories, -1 to disable. + score_index: tvm.const + Index of the scores/confidence of boxes. + Returns ------- output : tvm.Tensor @@ -182,41 +203,42 @@ def hybrid_nms(data, sorted_index, valid_count, box_data_length,), data.dtype) - for i in parallel(batch_size): + for i in range(batch_size): if iou_threshold > 0: if valid_count[i] > 0: # Reorder output nkeep = valid_count[i] if 0 < top_k < nkeep: nkeep = top_k - for j in range(nkeep): + for j in parallel(nkeep): for k in range(box_data_length): output[i, j, k] = data[i, sorted_index[i, j], k] box_indices[i, j] = sorted_index[i, j] if 0 < top_k < valid_count[i]: - for j in range(valid_count[i] - nkeep): + for j in parallel(valid_count[i] - nkeep): for k in range(box_data_length): output[i, j + nkeep, k] = -1.0 box_indices[i, j + nkeep] = -1 # Apply nms + box_start_idx = coord_start + batch_idx = i for j in range(valid_count[i]): - if output[i, j, 0] >= 0: - for k in range(valid_count[i]): + if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0): + box_a_idx = j + for k in parallel(valid_count[i]): check_iou = 0 - if k > j and output[i, k, 0] >= 0: + if k > j and output[i, k, score_index] > 0 \ + and (id_index < 0 or output[i, k, id_index] >= 0): if force_suppress: check_iou = 1 - elif id_index < 0 or output[i, j, 0] == output[i, k, 0]: + elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]: check_iou = 1 if check_iou > 0: - batch_idx = i - box_a_idx = j - box_b_idx = k - box_start_idx = coord_start - a_t = output[batch_idx, box_a_idx, box_start_idx + 1] - a_b = output[batch_idx, box_a_idx, box_start_idx + 3] a_l = output[batch_idx, box_a_idx, box_start_idx] + a_t = output[batch_idx, box_a_idx, box_start_idx + 1] a_r = output[batch_idx, box_a_idx, box_start_idx + 2] + a_b = output[batch_idx, box_a_idx, box_start_idx + 3] + box_b_idx = k b_t = output[batch_idx, box_b_idx, box_start_idx + 1] b_b = output[batch_idx, box_b_idx, box_start_idx + 3] b_l = output[batch_idx, box_b_idx, box_start_idx] @@ -227,22 +249,24 @@ def hybrid_nms(data, sorted_index, valid_count, u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area iou = 0.0 if u <= 0.0 else area / u if iou >= iou_threshold: - output[i, k, 0] = -1.0 + output[i, k, score_index] = -1.0 + if id_index >= 0: + output[i, k, id_index] = -1.0 box_indices[i, k] = -1 else: - for j in range(valid_count[i]): + for j in parallel(valid_count[i]): for k in range(box_data_length): output[i, j, k] = data[i, j, k] box_indices[i, j] = j # Set invalid entry to be -1 - for j in range(num_anchors - valid_count[i]): + for j in parallel(num_anchors - valid_count[i]): for k in range(box_data_length): output[i, j + valid_count[i], k] = -1.0 box_indices[i, j + valid_count[i]] = -1 # Only return max_output_size valid boxes num_valid_boxes = 0 if max_output_size > 0: - for j in range(valid_count[i]): + for j in parallel(valid_count[i]): if output[i, j, 0] >= 0: if num_valid_boxes == max_output_size: for k in range(box_data_length): @@ -263,9 +287,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, Parameters ---------- data : tvm.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. - The last dimension should be in format of - [class_id, score, box_left, box_top, box_right, box_bottom]. + 3-D tensor with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. valid_count : tvm.Tensor 1-D tensor for valid number of boxes. @@ -338,7 +360,8 @@ def non_max_suppression(data, valid_count, max_output_size=-1, tvm.const(force_suppress, dtype="bool"), tvm.const(top_k, dtype="int32"), tvm.const(coord_start, dtype="int32"), - tvm.const(id_index, dtype="int32")) + tvm.const(id_index, dtype="int32"), + tvm.const(score_index, dtype="int32")) if not return_indices and invalid_to_bottom: out = hybrid_rearrange_out(out) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 54c80c6e8c30..3a0b13489037 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -27,18 +27,18 @@ from topi.vision import ssd, non_max_suppression, get_valid_counts -def verify_get_valid_counts(dshape, score_threshold): +def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" batch_size, num_anchor, elem_length = dshape - np_data = np.random.uniform(size=dshape).astype(dtype) + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 for j in range(num_anchor): - score = np_data[i, j, 1] - if score > score_threshold: + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 @@ -55,8 +55,8 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): data = tvm.placeholder(dshape, name="data", dtype=dtype) - outs = get_valid_counts(data, score_threshold) - s = topi.generic.schedule_multibox_prior(outs) + outs = get_valid_counts(data, score_threshold, id_index, score_index) + s = topi.generic.schedule_get_valid_counts(outs) tvm_input_data = tvm.nd.array(np_data, ctx) tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) @@ -67,33 +67,26 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) for device in ['llvm', 'cuda', 'opencl']: + # Disable gpu test for now + if device != "llvm": + continue check_device(device) def test_get_valid_counts(): - verify_get_valid_counts((1, 2500, 6), 0) - verify_get_valid_counts((1, 2500, 6), -1) - verify_get_valid_counts((3, 1000, 6), 0.55) - verify_get_valid_counts((16, 500, 6), 0.95) + verify_get_valid_counts((1, 2500, 6), 0, 0, 1) + verify_get_valid_counts((1, 2500, 5), -1, -1, 0) + verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) + verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) -def test_non_max_suppression(): - dshape = (1, 5, 6) - indices_dshape = (1, 5) +def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index): + dshape = np_data.shape + batch, num_anchors, _ = dshape + indices_dshape = (batch, num_anchors) data = tvm.placeholder(dshape, name="data") - valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") - nms_threshold = 0.7 - force_suppress = True - nms_topk = 2 - - np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], - [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], - [1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype) - np_valid_count = np.array([4]).astype(valid_count.dtype) - np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[3, 0, -1, -1, -1]]) + valid_count = tvm.placeholder((batch,), dtype="int32", name="valid_count") def check_device(device): ctx = tvm.context(device, 0) @@ -103,11 +96,17 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): if device == 'llvm': - out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) - indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=False) + indices_out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index) else: - out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) - indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) + out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=False) + indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index) s = topi.generic.schedule_nms(out) indices_s = topi.generic.schedule_nms(indices_out) @@ -128,6 +127,30 @@ def check_device(device): check_device(device) +def test_non_max_suppression(): + np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) + + verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) + + np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], + [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], + [0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], + [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) + verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) + + + def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): data = tvm.placeholder(dshape, name="data") From c9dc9a3993d6b3f9b57977ddcdd23360974d2207 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Sun, 9 Jun 2019 16:24:11 -0700 Subject: [PATCH 100/176] Add MUL operator to relay tflite frontend (#3304) --- python/tvm/relay/frontend/tflite.py | 20 +++++++---- tests/python/frontend/tflite/test_forward.py | 38 ++++++++++++++++---- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3c3808d09712..ad2cd49eb12b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -64,6 +64,7 @@ def __init__(self, model, subgraph, exp_tab): 'MAX_POOL_2D': self.convert_max_pool2d, 'CONCATENATION': self.convert_concatenation, 'ADD': self.convert_add, + 'MUL': self.convert_mul, 'FULLY_CONNECTED': self.convert_fully_connected, } @@ -267,8 +268,8 @@ def convert_concatenation(self, op): out = self.convert_fused_activation_function(out, fused_activation_fn) return out - def convert_add(self, op): - """Convert TFLite add""" + def _convert_elemwise(self, relay_op, op): + """Generic method to Convert TFLite elemwise""" try: from tflite.Operator import Operator except ImportError: @@ -283,19 +284,26 @@ def convert_add(self, op): rhs_tensor = input_tensors[1] if self.has_expr(rhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses ADD operators + # In most cases, we can assume that TOCO fuses elemwise operators # with constants - it means both will be tensors. rhs_expr = self.get_expr(rhs_tensor.tensor_idx) else: - # However, in some corner cases, the ADD operator is not fused, + # However, in some corner cases, the elemwise operator is not fused, # we can receive as constant. rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - - out = _op.add(lhs_expr, rhs_expr) + out = relay_op(lhs_expr, rhs_expr) return out + def convert_add(self, op): + """Convert TFLite ADD""" + return self._convert_elemwise(_op.add, op) + + def convert_mul(self, op): + """Convert TFLite MUL""" + return self._convert_elemwise(_op.multiply, op) + def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 8fc2d550d556..677fbb87bf46 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -24,7 +24,6 @@ import numpy as np import tvm from tvm import relay -from tvm.contrib import util import tensorflow as tf from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, for i in range(len(tflite_output)): tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) - sess.close() - ####################################################################### # Pooling @@ -311,10 +308,10 @@ def test_forward_concatenation(): ####################################################################### -# Add +# Element-wise # --- -def _test_add(data): +def _test_elemwise(math_op, data): """ One iteration of add """ assert len(data) == 2 @@ -329,10 +326,19 @@ def _test_add(data): # Test with tensor and constant with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] - out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) + out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) +####################################################################### +# Add +# --- + +def _test_add(data): + """ One iteration of add """ + return _test_elemwise(math_ops.add, data) + + def test_forward_add(): """ Add """ _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), @@ -343,6 +349,25 @@ def test_forward_add(): np.arange(3.0, dtype=np.float32).reshape((1, 3))]) +####################################################################### +# Mul +# --- + +def _test_mul(data): + """ One iteration of mul """ + return _test_elemwise(math_ops.multiply, data) + + +def test_forward_mul(): + """ Mul """ + _test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) + _test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) + _test_mul([np.arange(3.0, dtype=np.float32).reshape((1, 3)), + np.arange(3.0, dtype=np.float32).reshape((1, 3))]) + + ####################################################################### # Squeeze # ------- @@ -514,6 +539,7 @@ def test_forward_inception_v4_net(): # Math test_forward_add() + test_forward_mul() # End to End test_forward_mobilenet_v1() From 05ded204dbb5d16ac92a36c5d56229a3646825c4 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Sun, 9 Jun 2019 16:41:22 -0700 Subject: [PATCH 101/176] add another default location to verilator (#3324) --- vta/apps/tsim_example/cmake/modules/hw.cmake | 8 +++++++- vta/hardware/chisel/Makefile | 16 +++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vta/apps/tsim_example/cmake/modules/hw.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake index 019be129f243..102df9987752 100644 --- a/vta/apps/tsim_example/cmake/modules/hw.cmake +++ b/vta/apps/tsim_example/cmake/modules/hw.cmake @@ -113,7 +113,13 @@ else() # Build shared library (.so) set(VTA_HW_DPI_DIR ${VTA_DIR}/hardware/dpi) - set(VERILATOR_INC_DIR /usr/local/share/verilator/include) + if (EXISTS /usr/local/share/verilator/include) + set(VERILATOR_INC_DIR /usr/local/share/verilator/include) + elseif (EXISTS /usr/share/verilator/include) + set(VERILATOR_INC_DIR /usr/share/verilator/include) + else() + message(FATAL_ERROR "[TSIM_HW] Verilator include directory not found") + endif() set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp) if (NOT TSIM_USE_TRACE STREQUAL "off") diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 91e40a022337..7e90168c21c6 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -15,15 +15,21 @@ # specific language governing permissions and limitations # under the License. -# Change this variable if Verilator is installed on a different location -VERILATOR_INC_DIR ?= /usr/local/share/verilator/include - ifeq (, $(shell which verilator)) $(error "No Verilator in $(PATH), consider doing apt-get install verilator") endif -ifeq (, $(wildcard $(VERILATOR_INC_DIR)/*)) - $(error "Verilator include directory is not set properly") +# Change VERILATOR_INC_DIR if Verilator is installed on a different location +ifeq (, $(VERILATOR_INC_DIR)) + ifeq (, $(wildcard /usr/local/share/verilator/include/*)) + ifeq (, $(wildcard /usr/share/verilator/include/*)) + $(error "Verilator include directory is not set properly") + else + VERILATOR_INC_DIR := /usr/share/verilator/include + endif + else + VERILATOR_INC_DIR := /usr/local/share/verilator/include + endif endif CONFIG = DefaultF1Config From b35350efc8e536e869528e10f9847e36b2b1495e Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Mon, 10 Jun 2019 17:24:22 +0100 Subject: [PATCH 102/176] [DOC] minor gramatical improvements to tensor_expr_get_started (#3330) --- tutorials/tensor_expr_get_started.py | 84 ++++++++++++++-------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/tutorials/tensor_expr_get_started.py b/tutorials/tensor_expr_get_started.py index a0b84f0e81ca..1b5eabcd56ea 100644 --- a/tutorials/tensor_expr_get_started.py +++ b/tutorials/tensor_expr_get_started.py @@ -19,7 +19,7 @@ ================================== **Author**: `Tianqi Chen `_ -This is an introduction tutorial to Tensor expression language in TVM. +This is an introductory tutorial to the Tensor expression language in TVM. TVM uses a domain specific tensor expression for efficient kernel construction. In this tutorial, we will demonstrate the basic workflow to use @@ -48,15 +48,16 @@ # ------------------------ # As a first step, we need to describe our computation. # TVM adopts tensor semantics, with each intermediate result -# represented as multi-dimensional array. The user need to describe -# the computation rule that generate the tensors. +# represented as a multi-dimensional array. The user needs to describe +# the computation rule that generates the tensors. # # We first define a symbolic variable n to represent the shape. # We then define two placeholder Tensors, A and B, with given shape (n,) # -# We then describe the result tensor C, with a compute operation. -# The compute function takes the shape of the tensor, as well as a lambda function -# that describes the computation rule for each position of the tensor. +# We then describe the result tensor C, with a compute operation. The +# compute function takes the shape of the tensor, as well as a lambda +# function that describes the computation rule for each position of +# the tensor. # # No computation happens during this phase, as we are only declaring how # the computation should be done. @@ -70,9 +71,10 @@ ###################################################################### # Schedule the Computation # ------------------------ -# While the above lines describes the computation rule, we can compute -# C in many ways since the axis of C can be computed in data parallel manner. -# TVM asks user to provide a description of computation called schedule. +# While the above lines describe the computation rule, we can compute +# C in many ways since the axis of C can be computed in a data +# parallel manner. TVM asks the user to provide a description of the +# computation called a schedule. # # A schedule is a set of transformation of computation that transforms # the loop of computations in the program. @@ -120,33 +122,33 @@ # ----------- # After we have finished specifying the schedule, we can compile it # into a TVM function. By default TVM compiles into a type-erased -# function that can be directly called from python side. +# function that can be directly called from the python side. # # In the following line, we use tvm.build to create a function. # The build function takes the schedule, the desired signature of the -# function(including the inputs and outputs) as well as target language +# function (including the inputs and outputs) as well as target language # we want to compile to. # -# The result of compilation fadd is a GPU device function(if GPU is involved) -# that can as well as a host wrapper that calls into the GPU function. -# fadd is the generated host wrapper function, it contains reference -# to the generated device function internally. +# The result of compilation fadd is a GPU device function (if GPU is +# involved) as well as a host wrapper that calls into the GPU +# function. fadd is the generated host wrapper function, it contains +# a reference to the generated device function internally. # fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") ###################################################################### # Run the Function # ---------------- -# The compiled function TVM function is designed to be a concise C API -# that can be invoked from any languages. +# The compiled TVM function is exposes a concise C API +# that can be invoked from any language. # -# We provide an minimum array API in python to aid quick testing and prototyping. -# The array API is based on `DLPack `_ standard. +# We provide a minimal array API in python to aid quick testing and prototyping. +# The array API is based on the `DLPack `_ standard. # # - We first create a GPU context. -# - Then tvm.nd.array copies the data to GPU. +# - Then tvm.nd.array copies the data to the GPU. # - fadd runs the actual computation. -# - asnumpy() copies the GPU array back to CPU and we can use this to verify correctness +# - asnumpy() copies the GPU array back to the CPU and we can use this to verify correctness # ctx = tvm.context(tgt, 0) @@ -176,14 +178,14 @@ ###################################################################### # .. note:: Code Specialization # -# As you may noticed, during the declaration, A, B and C both -# takes the same shape argument n. TVM will take advantage of this -# to pass only single shape argument to the kernel, as you will find in +# As you may have noticed, the declarations of A, B and C all +# take the same shape argument, n. TVM will take advantage of this +# to pass only a single shape argument to the kernel, as you will find in # the printed device code. This is one form of specialization. # # On the host side, TVM will automatically generate check code # that checks the constraints in the parameters. So if you pass -# arrays with different shapes into the fadd, an error will be raised. +# arrays with different shapes into fadd, an error will be raised. # # We can do more specializations. For example, we can write # :code:`n = tvm.convert(1024)` instead of :code:`n = tvm.var("n")`, @@ -195,13 +197,13 @@ # Save Compiled Module # -------------------- # Besides runtime compilation, we can save the compiled modules into -# file and load them back later. This is called ahead of time compilation. +# a file and load them back later. This is called ahead of time compilation. # -# The following code first does the following step: +# The following code first performs the following steps: # # - It saves the compiled host module into an object file. # - Then it saves the device module into a ptx file. -# - cc.create_shared calls a env compiler(gcc) to create a shared library +# - cc.create_shared calls a compiler (gcc) to create a shared library # from tvm.contrib import cc from tvm.contrib import util @@ -218,9 +220,9 @@ ###################################################################### # .. note:: Module Storage Format # -# The CPU(host) module is directly saved as a shared library(so). -# There can be multiple customized format on the device code. -# In our example, device code is stored in ptx, as well as a meta +# The CPU (host) module is directly saved as a shared library (.so). +# There can be multiple customized formats of the device code. +# In our example, the device code is stored in ptx, as well as a meta # data json file. They can be loaded and linked separately via import. # @@ -228,8 +230,8 @@ # Load Compiled Module # -------------------- # We can load the compiled module from the file system and run the code. -# The following code load the host and device module separately and -# re-link them together. We can verify that the newly loaded function works. +# The following code loads the host and device module separately and +# re-links them together. We can verify that the newly loaded function works. # fadd1 = tvm.module.load(temp.relpath("myadd.so")) if tgt == "cuda": @@ -261,11 +263,11 @@ # .. note:: Runtime API and Thread-Safety # # The compiled modules of TVM do not depend on the TVM compiler. -# Instead, it only depends on a minimum runtime library. -# TVM runtime library wraps the device drivers and provides -# thread-safe and device agnostic call into the compiled functions. +# Instead, they only depend on a minimum runtime library. +# The TVM runtime library wraps the device drivers and provides +# thread-safe and device agnostic calls into the compiled functions. # -# This means you can call the compiled TVM function from any thread, +# This means that you can call the compiled TVM functions from any thread, # on any GPUs. # @@ -275,7 +277,7 @@ # TVM provides code generation features into multiple backends, # we can also generate OpenCL code or LLVM code that runs on CPU backends. # -# The following codeblocks generate opencl code, creates array on opencl +# The following code blocks generate OpenCL code, creates array on an OpenCL # device, and verifies the correctness of the code. # if tgt.startswith('opencl'): @@ -296,12 +298,12 @@ # This tutorial provides a walk through of TVM workflow using # a vector add example. The general workflow is # -# - Describe your computation via series of operations. +# - Describe your computation via a series of operations. # - Describe how we want to compute use schedule primitives. # - Compile to the target function we want. # - Optionally, save the function to be loaded later. # -# You are more than welcomed to checkout other examples and -# tutorials to learn more about the supported operations, schedule primitives +# You are more than welcome to checkout other examples and +# tutorials to learn more about the supported operations, scheduling primitives # and other features in TVM. # From e19b899094a8c24e48c591391977da05faab24de Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Mon, 10 Jun 2019 19:01:10 +0100 Subject: [PATCH 103/176] Drop trailing whitespace (#3331) --- tests/scripts/task_build.sh | 4 ++-- tests/scripts/task_clean.sh | 4 ++-- tests/scripts/task_cpp_unittest.sh | 4 ++-- tests/scripts/task_golang.sh | 4 ++-- tests/scripts/task_python_docs.sh | 4 ++-- tests/scripts/task_python_integration.sh | 4 ++-- tests/scripts/task_python_topi.sh | 4 ++-- tests/scripts/task_python_unittest.sh | 4 ++-- tests/scripts/task_python_vta.sh | 4 ++-- tests/scripts/task_verilog_test.sh | 4 ++-- tests/scripts/task_web_build.sh | 4 ++-- tests/scripts/task_web_test.sh | 4 ++-- 12 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/scripts/task_build.sh b/tests/scripts/task_build.sh index 2440f0f6e79c..fbf3a63df0b2 100755 --- a/tests/scripts/task_build.sh +++ b/tests/scripts/task_build.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_clean.sh b/tests/scripts/task_clean.sh index bee75cebc742..d5e5f0728c1d 100755 --- a/tests/scripts/task_clean.sh +++ b/tests/scripts/task_clean.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_cpp_unittest.sh b/tests/scripts/task_cpp_unittest.sh index 793eb91b5d11..15a43b9b465e 100755 --- a/tests/scripts/task_cpp_unittest.sh +++ b/tests/scripts/task_cpp_unittest.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_golang.sh b/tests/scripts/task_golang.sh index 3e72756fcafe..ee9ec19c4201 100755 --- a/tests/scripts/task_golang.sh +++ b/tests/scripts/task_golang.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index b2a3fc1cb176..be511725e774 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 85dd6de64f6d..b472e96f9e18 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_python_topi.sh b/tests/scripts/task_python_topi.sh index a204f38c6cc6..59966cffc0bf 100755 --- a/tests/scripts/task_python_topi.sh +++ b/tests/scripts/task_python_topi.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index 7879c8d64e11..caa82f375ec8 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_python_vta.sh b/tests/scripts/task_python_vta.sh index 4345fc2ba39b..54fa6f06e39b 100755 --- a/tests/scripts/task_python_vta.sh +++ b/tests/scripts/task_python_vta.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_verilog_test.sh b/tests/scripts/task_verilog_test.sh index 8d725844bb5f..81da1c491a7e 100755 --- a/tests/scripts/task_verilog_test.sh +++ b/tests/scripts/task_verilog_test.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_web_build.sh b/tests/scripts/task_web_build.sh index 25854f5e6d21..ec1d15a04fbb 100755 --- a/tests/scripts/task_web_build.sh +++ b/tests/scripts/task_web_build.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/scripts/task_web_test.sh b/tests/scripts/task_web_test.sh index 4b383a1780aa..947a133c1a7b 100755 --- a/tests/scripts/task_web_test.sh +++ b/tests/scripts/task_web_test.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY From 9b3a340142bb0df5b3c2a9fd2c29c5be8a493608 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Mon, 10 Jun 2019 19:01:58 +0100 Subject: [PATCH 104/176] [CI] Fix shell script exit codes (#3329) The exist code of a posix compilant shell is 0..255. Attempting to return -1 will error in some shells and implicitly cast to 255 in others. Fix it by returning a legal return value. --- tests/scripts/task_lint.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 5116a42afb93..544ef7224770 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -41,7 +41,7 @@ if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then echo "- Create file_list.txt in your text editor" echo "- Copy paste the above content in file-list into file_list.txt" echo "- python3 tests/lint/add_asf_header.py file_list.txt" - exit -1 + exit 1 fi echo "Check codestyle of c++ code..." @@ -59,5 +59,5 @@ echo "---------Error Log----------" cat /tmp/$$.logclean.txt echo "----------------------------" if grep --quiet -E "warning|error" < /tmp/$$.logclean.txt; then - exit -1 + exit 1 fi From b54019fdd13088cad37b96fc82377a711fe3a684 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 10 Jun 2019 11:34:33 -0700 Subject: [PATCH 105/176] Add all parameters to from_tensorflow docs (#3321) --- nnvm/python/nnvm/frontend/tensorflow.py | 16 ++++++++++++++-- python/tvm/relay/frontend/tensorflow.py | 16 ++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 244b48eb3d5a..e59a4e76c465 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1188,7 +1188,7 @@ def __init__(self): self._input_shapes = {} def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): - """Construct nnvm nodes from tensorflow graph definition - GraphDef. + """Construct nnvm nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to NNVM. Some of the assumptions listed below. @@ -1214,6 +1214,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): shape : Dictionary of input dimensions (Optional) Graph level input shape dictionary. + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : nnvm.sym.Symbol @@ -1599,7 +1602,7 @@ def _fix_extranodes(self, op_name, attr, inputs): return inputs def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): - """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph. + """Load tensorflow graph which is a python tensorflow graph object into nnvm graph. The companion parameters will be handled automatically. Parameters @@ -1607,6 +1610,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): graph : GraphDef object Tensorflow GraphDef + layout : target layout to be used (Optional) + NCHW only supported now to enable NHWC models on GPU. + + shape : Dictionary of input dimensions (Optional) + Graph level input shape dictionary. + + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : nnvm.Symbol diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 45ae2cd19cd1..4f241952db2e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1787,7 +1787,7 @@ def __init__(self): self._branches = {} def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): - """Construct relay nodes from tensorflow graph definition - GraphDef. + """Construct relay nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to Relay. Some of the assumptions listed below. @@ -1813,6 +1813,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): shape : Dictionary of input dimensions (Optional) Graph level input shape dictionary. + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : relay.op @@ -2276,7 +2279,7 @@ def _convert_operator(self, op_name, inputs, attrs, def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): - """ Load tensorflow graph which is a python tensorflow graph object into relay. + """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. Parameters @@ -2284,6 +2287,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): graph : GraphDef object Tensorflow GraphDef + layout : target layout to be used (Optional) + NCHW only supported now to enable NHWC models on GPU. + + shape : Dictionary of input dimensions (Optional) + Graph level input shape dictionary. + + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : relay.op From 317b98ae927f58a319f16c9f49e9d37870b94152 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 10 Jun 2019 11:34:52 -0700 Subject: [PATCH 106/176] Fix Error messages in tflite.py (#3320) --- nnvm/python/nnvm/frontend/keras.py | 1 - python/tvm/relay/frontend/keras.py | 1 - python/tvm/relay/frontend/tflite.py | 9 +++++---- tests/python/frontend/keras/test_forward.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index 7af8cf8833dd..f647a644bd2b 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -180,7 +180,6 @@ def _convert_convolution(insym, keras_layer, symtab): else: kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape weight = weightList[0].transpose([3, 2, 0, 1]) - dilation = [1, 1] if isinstance(keras_layer.dilation_rate, (list, tuple)): dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] else: diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2648a5a6637b..5d5e50ff3559 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -203,7 +203,6 @@ def _convert_convolution(inexpr, keras_layer, etab): else: kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape weight = weightList[0].transpose([3, 2, 0, 1]) - dilation = [1, 1] if isinstance(keras_layer.dilation_rate, (list, tuple)): dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] else: diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ad2cd49eb12b..9c8f50f0b020 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -156,7 +156,7 @@ def get_tensor_value(self, tensor_wrapper): if tensor_wrapper.tensor.Type() == TensorType.INT32: return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( tensor_wrapper.tensor.ShapeAsNumpy()) - raise NotImplementedError("Not support tensor type {}" + raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) def get_tensor_type_str(self, tensor_type): @@ -172,7 +172,8 @@ def get_tensor_type_str(self, tensor_type): return "float32" if tensor_type == TensorType.INT32: return "int32" - raise NotImplementedError("Not support tensor type {}".format(str(tensor_type))) + raise NotImplementedError("Tensor type {} is currently not supported" + .format(str(tensor_type))) def convert_conv2d(self, op): """Convert TFLite conv2d""" @@ -450,8 +451,8 @@ def convert_conv(self, op, conv_type): conv_options = DepthwiseConv2DOptions() conv_options.Init(op_options.Bytes, op_options.Pos) depth_multiplier = conv_options.DepthMultiplier() - assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \ - "no matter original value be set by 0.25, 0.5 or any else" + assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \ + "original value is set to 0.25, 0.5 or anything else" else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(conv_type)) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 8817d4faaeaa..0794db987892 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -21,7 +21,7 @@ from tvm.relay.testing.config import ctx_list import keras -# prevent keras from using up all gpu memory +# prevent Keras from using up all gpu memory import tensorflow as tf from keras.backend.tensorflow_backend import set_session config = tf.ConfigProto() From c435ab9692d17d0a81db59ca7e8af7fc0156adcb Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 10 Jun 2019 13:08:28 -0700 Subject: [PATCH 107/176] Support x86 dilation conv2d and improve multi-batch conv2d (#3308) * Support x86 dilation conv2d and improve multi-batch conv2d * Fix lint --- topi/python/topi/x86/conv2d.py | 19 ++++++++++++------- topi/python/topi/x86/conv2d_avx_1x1.py | 12 ++++++------ topi/python/topi/x86/conv2d_avx_common.py | 12 ++++++------ topi/tests/python/test_topi_conv2d_NCHWc.py | 8 ++++---- tutorials/frontend/deploy_ssd_gluoncv.py | 5 ++--- 5 files changed, 30 insertions(+), 26 deletions(-) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 08becf428c27..fd3f19a5d060 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -505,8 +505,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, # we keep them for debug convenience when dumping autotvm workload HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) - assert (dh, dw) == (1, 1), "Does not support dilation" + dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ + else (dilation, dilation) n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn @@ -519,6 +519,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), @@ -526,8 +529,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, strides, padding, out_dtype) # output shape - out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1 - out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1 + out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) # DOPAD @@ -553,8 +556,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw, - ic_f_inner * n_elems + ic_s_inner] + tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh*dilation_h, + ow*WSTR+kw*dilation_w, + ic_f_inner * n_elems + ic_s_inner] .astype(out_dtype) * kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(out_dtype), @@ -580,7 +584,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, + tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh*dilation_h, + ow*WSTR+kw*dilation_w, ic%ic_bn].astype(out_dtype) * kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block], axis=[ic, kh, kw]), diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 4994d4580ab5..256cea569c68 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -134,7 +134,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): batch, ic_chunk, ih, iw, ic_block = s[A].op.axis - parallel_axis = s[A].fuse(ic_chunk, ih) + parallel_axis = s[A].fuse(batch, ic_chunk, ih) s[A].parallel(parallel_axis) C, O = conv_out, last @@ -146,7 +146,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) - parallel_axis = s[C].fuse(oc_chunk, oh_outer) + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) s[CC].compute_at(s[C], parallel_axis) if C == O: s[C].parallel(parallel_axis) @@ -172,7 +172,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - parallel_axis = s[O].fuse(oc_chunk, oh_outer) + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) @@ -203,7 +203,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): batch, ic_chunk, ih, iw, ic_block = s[A].op.axis - parallel_axis = s[A].fuse(ic_chunk, ih) + parallel_axis = s[A].fuse(batch, ic_chunk, ih) s[A].parallel(parallel_axis) C, O = conv_out, last @@ -215,7 +215,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) - parallel_axis = s[C].fuse(oc_chunk, oh_outer) + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) s[CC].compute_at(s[C], parallel_axis) if C == O: s[C].parallel(parallel_axis) @@ -246,7 +246,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - parallel_axis = s[O].fuse(oc_chunk, oh_outer) + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index 3ab68d71b948..44867c9e33d5 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -143,10 +143,10 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): C, O = conv_out, last CC = s.cache_write(C, 'global') - _, oc_chunk, oh, ow, oc_block = s[C].op.axis + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis ow_chunk, ow_block = s[C].split(ow, factor=reg_n) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[C].fuse(oc_chunk, oh) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) s[C].vectorize(oc_block) if C == O: s[C].parallel(parallel_axis) @@ -171,7 +171,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): batch, oc_chunk, oh, ow, oc_block = s[O].op.axis ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(oc_chunk, oh) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) @@ -214,10 +214,10 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): C, O = conv_out, last CC = s.cache_write(C, 'global') - _, oc_chunk, oh, ow, oc_block = s[C].op.axis + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis ow_chunk, ow_block = s[C].split(ow, factor=reg_n) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[C].fuse(oc_chunk, oh) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) s[C].vectorize(oc_block) if C == O: s[C].parallel(parallel_axis) @@ -251,7 +251,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): batch, oc_chunk, oh, ow, oc_block = s[O].op.axis ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(oc_chunk, oh) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) diff --git a/topi/tests/python/test_topi_conv2d_NCHWc.py b/topi/tests/python/test_topi_conv2d_NCHWc.py index 5aca0c00c4d6..26b4642bd333 100644 --- a/topi/tests/python/test_topi_conv2d_NCHWc.py +++ b/topi/tests/python/test_topi_conv2d_NCHWc.py @@ -49,7 +49,6 @@ def _transform_bias(bias, bn): def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"): - assert dilation == 1, "conv2d_NCHWc does not support dilation for now." print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) @@ -79,7 +78,8 @@ def get_ref_data(): a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype) b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype) - c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) if add_bias: c_np += b_np if add_relu: @@ -149,8 +149,8 @@ def test_conv2d_NCHWc(): verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True) verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True) - # disable dilation test since it is not supported by NCHW[x]c conv for now. - # verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2) + # dilation + verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2) # batch size verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1) diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index f536679183c8..829957b3c658 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -47,9 +47,6 @@ # # To get best performance fo SSD on Intel graphics, # change target argument to 'opencl -device=intel_graphics' -# -# SSD with VGG as body network is not supported yet since -# x86 conv2d schedule doesn't support dilation. supported_model = [ 'ssd_512_resnet50_v1_voc', @@ -57,6 +54,8 @@ 'ssd_512_resnet101_v2_voc', 'ssd_512_mobilenet1.0_voc', 'ssd_512_mobilenet1.0_coco', + 'ssd_300_vgg16_atrous_voc' + 'ssd_512_vgg16_atrous_coco', ] model_name = supported_model[0] From a33e7aa62a56da4ea07395b5b8a71741ca698d44 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Mon, 10 Jun 2019 14:33:11 -0700 Subject: [PATCH 108/176] [Autotvm] Support override (#3292) --- python/tvm/autotvm/task/topi_integration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index ef0cb568071c..c48d4f58edce 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -284,7 +284,7 @@ def get(allow_duplicate=False): return TaskExtractEnv.current -def register_topi_compute(topi_compute, target_keys, template_keys, func=None): +def register_topi_compute(topi_compute, target_keys, template_keys, func=None, override=False): """Register a tunable template for a topi compute function. After the registration, this topi compute will become a configuration dispatcher. It uses @@ -333,7 +333,7 @@ def config_dispatcher(*args, **kwargs): config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute] - @config_dispatcher.register(template_keys) + @config_dispatcher.register(template_keys, override=override) def template_call(cfg, *args, **kwargs): """call the topi func and attach workload to compute node""" assert not kwargs, "Do not support kwargs in template function call" @@ -372,7 +372,7 @@ def template_call(cfg, *args, **kwargs): return _decorator -def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None): +def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None, override=False): """Register a tunable template for a topi schedule function. After the registration. This topi schedule will become a configuration dispatcher. It dispatches @@ -438,7 +438,7 @@ def traverse(tensors): config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule] - @config_dispatcher.register(template_keys) + @config_dispatcher.register(template_keys, override=override) def template_call(cfg, outs, *args, **kwargs): """call the schedule func""" if f == topi_schedule.fdefault: From a654cf15c3f024aaf5aad531b15e5fd08dcd2b46 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 10 Jun 2019 15:02:41 -0700 Subject: [PATCH 109/176] [Relay][heterogeneous] Fix tuple annotation (#3311) * [Relay][heterogeneous] Fix TupleGetItem * retrigger ci * retrigger ci --- src/relay/pass/device_annotation.cc | 14 ++++++-- tests/python/relay/test_pass_annotation.py | 38 +++++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 02d6d9e1fefb..8eeb493f1feb 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -68,6 +68,7 @@ class ValidateAnnotation : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { + ExprVisitor::VisitExpr_(call_node); if (IsOnDeviceNode(call_node)) { int device_type = GetDeviceId(call_node); if (annotation_map_.count(call_node)) { @@ -86,7 +87,14 @@ class ValidateAnnotation : private ExprVisitor { annotation_map_.insert({node, GetDeviceId(call_node)}); } } - ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const TupleGetItemNode* get_elem) final { + ExprVisitor::VisitExpr_(get_elem); + const auto* tn = get_elem->tuple.operator->(); + if (annotation_map_.count(tn)) { + annotation_map_.insert({get_elem, annotation_map_.at(tn)}); + } } /* @@ -253,7 +261,9 @@ class RewriteAnnotation : public ExprMutator { if (src->is_type() || src->is_type()) { return annotation_map_.at(dst) != fallback_device_; } else { - return false; + // There shouldn't be any copy nodes between var/constant and another + // expression. + return !(src->is_type() || src->is_type()); } } else { return false; diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index ba2c249693b7..84a5c8749079 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -554,6 +554,7 @@ def expected(): res = mod.get_output(0).asnumpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + def test_check_run(): for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), ("opencl", str(tvm.target.intel_graphics()))]: @@ -563,7 +564,41 @@ def test_check_run(): run_fusible_network(dev, tgt) run_unpropagatable_graph(dev, tgt) - + +def test_tuple_get_item(): + dev = "cuda" + if not tvm.module.enabled(dev): + print("Skip test because %s is not enabled." % dev) + return + + cpu_ctx = tvm.cpu(0) + gpu_ctx = tvm.context(dev) + + def expected(): + x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) + split = relay.op.split(x, 3) + elem0 = relay.device_copy(split[0], gpu_ctx, cpu_ctx) + elem1 = relay.device_copy(split[1], gpu_ctx, cpu_ctx) + sub = elem0 - elem1 + func = relay.Function(relay.ir_pass.free_vars(sub), sub) + return func + + def annotated(): + x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) + split = relay.op.split(x, 3) + split = split.astuple() + split = relay.annotation.on_device(split, gpu_ctx) + split = relay.TupleWrapper(split, 3) + sub = split[0] - split[1] + func = relay.Function(relay.ir_pass.free_vars(sub), sub) + func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) + return func + + annotated_func = relay.ir_pass.infer_type(annotated()) + expected_func = relay.ir_pass.infer_type(expected()) + assert relay.ir_pass.graph_equal(annotated_func, expected_func) + + if __name__ == "__main__": test_redundant_annotation() test_annotate_expr() @@ -571,3 +606,4 @@ def test_check_run(): test_annotate_none() test_conv_network() test_check_run() + test_tuple_get_item() From e3ee1aafa3e6bf71b6a7452ef7498f4e07cfbce2 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 10 Jun 2019 15:29:01 -0700 Subject: [PATCH 110/176] Add PAD operator to relay tflite frontend (#3310) --- python/tvm/relay/frontend/tflite.py | 26 +++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 30 ++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9c8f50f0b020..3a13473202a3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): 'ADD': self.convert_add, 'MUL': self.convert_mul, 'FULLY_CONNECTED': self.convert_fully_connected, + 'PAD': self.convert_pad, } def check_unsupported_ops(self): @@ -596,6 +597,31 @@ def convert_pool2d(self, op, pool_type): return out + def convert_pad(self, op): + """Convert TFLite PAD""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + # TFLite only support CONSTANT mode and does not support constant_values parameter. + # tensor + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # paddings + pad_list = self.get_tensor_value(input_tensors[1]) + # convert list of lists to tuple of tuples + paddings = tuple(tuple(l) for l in pad_list) + + # Use default pad_value 0 because TFLite does not support constant_values parameter + out = _op.nn.pad(in_expr, paddings) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 677fbb87bf46..7da2b851bb3f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -394,6 +394,35 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2]) _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3]) + +####################################################################### +# Pad +# --- + +def _test_pad(data): + """ One iteration of PAD """ + + assert len(data) == 2 + + # Test with tensor and constant + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] + out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) + compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) + + +def test_forward_pad(): + """ Pad """ + _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32)]) + _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), + np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32)]) + _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)), + np.array([[1, 1], [2, 2]], dtype=np.int32)]) + _test_pad([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)), + np.array([[1, 1], [2, 2]], dtype=np.int32)]) + + ####################################################################### # Softmax # ------- @@ -528,6 +557,7 @@ def test_forward_inception_v4_net(): if __name__ == '__main__': # Transforms test_forward_concatenation() + test_forward_pad() test_forward_reshape() test_forward_squeeze() From ac20b98fe59e682f4aaeb3fc5f6ed734aa668a27 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 10 Jun 2019 17:47:31 -0700 Subject: [PATCH 111/176] [relay][vm] move vm opt passes to pass manager (#3323) --- python/tvm/relay/backend/vm.py | 52 ++++++++----- src/relay/backend/vm/compiler.cc | 24 ++++-- src/relay/backend/vm/inline_primitives.cc | 92 ++++++++++++----------- src/relay/backend/vm/lambda_lift.cc | 80 ++++++++++---------- src/relay/pass/pass_manager.cc | 15 ++-- 5 files changed, 150 insertions(+), 113 deletions(-) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index bebadd167fe9..3b9946a3958d 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -20,24 +20,45 @@ Implements a Python interface to compiling and executing on the Relay VM. """ +import numpy as np + import tvm from tvm._ffi.function import Object -import numpy as np -from .. import ir_pass +from .. import transform from ..backend.interpreter import Executor -from ..expr import GlobalVar, Function, Expr +from ..expr import GlobalVar, Expr from . import _vm Object = Object -def optimize(expr, mod=None): - # TODO: We need to move this optimization code into the optimizer/pass manager - ck_expr = ir_pass.infer_type(expr, mod=mod) - simplified_expr = ir_pass.simplify_inference(ck_expr) - simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod) - fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod) - ck_fused = ir_pass.infer_type(fused_expr, mod=mod) - return ck_fused +def optimize(mod): + """Perform several optimizations on a module before executing it in the + Relay virtual machine. + + Parameters + ---------- + mod : tvm.relay.Module + The module to optimize. + + Returns + ------- + ret : tvm.relay.Module + The optimized module. + """ + main_func = mod[mod.entry_func] + + opt_passes = [] + if not main_func.params and isinstance(main_func.body, GlobalVar): + opt_passes.append(transform.EtaExpand()) + + opt_passes = opt_passes + [ + transform.SimplifyInference(), + transform.FuseOps(), + transform.InferType() + ] + + seq = transform.Sequential(opt_passes) + return seq(mod) def _convert(arg, cargs): if isinstance(arg, np.ndarray): @@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args): args: List[tvm.NDArray, np.ndarray] The arguments to evaluate. """ - main_func = mod[mod.entry_func] - - if not main_func.params and isinstance(main_func.body, GlobalVar): - main_func = ir_pass.eta_expand(main_func.body, mod) - - assert isinstance(main_func, Function) - main_func = optimize(mod[mod.entry_func], mod) - mod[mod.entry_func] = main_func + mod = optimize(mod) args = list(args) assert isinstance(args, list) cargs = convert(args) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index db98a9a9d3fd..07633fc346ec 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -38,15 +38,22 @@ namespace tvm { namespace relay { + +namespace transform { + +Pass LambdaLift(); +Pass InlinePrimitives(); + +} // namespace transform + namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; +using namespace relay::transform; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); -Module LambdaLift(const Module& module); -Module InlinePrimitives(const Module& module); template using NodeMap = std::unordered_map; @@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F } Module OptimizeModule(const Module& mod) { - ToANormalForm(mod->entry_func, mod); - InlinePrimitives(mod); - LambdaLift(mod); - return InlinePrimitives(mod); + transform::Sequential seq({transform::ToANormalForm(), + transform::InlinePrimitives(), + transform::LambdaLift(), + transform::InlinePrimitives()}); + auto pass_ctx = transform::PassContext::Create(); + tvm::With ctx(pass_ctx); + return seq(mod); } void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index b033a37e42b8..1e561f8a8214 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -37,6 +37,21 @@ namespace tvm { namespace relay { namespace vm { +// TODO(@jroesch): write verifier + +/* This pass will eliminate primitives which have been lifted by the ANF + * transform inlining them directly into call sites. + * + * This makes VM related code generation easier as the call target is always + * a primitive function. + * + * let prim = fn(...) { ... }; + * prim(...) + * + * will become: + * + * (fn(...) { ... })(...) + */ struct PrimitiveInliner : ExprMutator { Module module_; std::unordered_map var_map; @@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator { } } - Function Inline(const Function& func) { - DLOG(INFO) << "Before inlining primitives: " << std::endl - << "func= " << AsText(func, false) << std::endl; - - auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); - - inlined = Downcast(DeadCodeElimination(inlined)); - - DLOG(INFO) << "After inlining primitives" << std::endl - << "after_func= " << AsText(inlined, false) << std::endl; - return inlined; + Module Inline() { + auto gvar_funcs = module_->functions; + for (auto pair : gvar_funcs) { + auto global = pair.first; + auto func = pair.second; + DLOG(INFO) << "Before inlining primitives: " << global + << std::endl << AsText(func, false); + + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(global, func, true); + + DLOG(INFO) << "After inlining primitives: " << global + << std::endl << AsText(func, false); + } + return module_; } }; -// TODO(@jroesch): write verifier - -/* This pass will eliminate primitives which have been lifted by the ANF - * transform inlining them directly into call sites. - * - * This makes VM related code generation easier as the call target is always - * a primitive function. - * - * let prim = fn(...) { ... }; - * prim(...) - * - * will become: - * - * (fn(...) { ... })(...) - */ -Module InlinePrimitives(const Module& module) { - PrimitiveInliner inliner(module); +} // namespace vm - tvm::Map updates; +namespace transform { - // There is an ordering bug here. - for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, inliner.Inline(func)); - } +Pass InlinePrimitives() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::PrimitiveInliner(m).Inline(); + }; + auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); + // Eliminate dead code for each function after inlining. + return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.InlinePrimitives") +.set_body_typed(InlinePrimitives); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 13d8112440fb..a55a9273d078 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) { return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); } +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. + */ struct LambdaLifter : ExprMutator { Module module_; - std::vector> lifted_; explicit LambdaLifter(const Module& module) : module_(module) {} Expr VisitExpr_(const FunctionNode* func_node) final { @@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator { auto free_type_vars = FreeTypeVars(func, module_); auto body = Downcast(ExprMutator::VisitExpr_(func_node)); - // When performing this optimization there are two - // cases. + // When performing this optimization there are two cases. // // The first case in which we have no free variables // we can just lift the function into the global @@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator { // // // The second case requires that we generate a special - // function with makes a distinction between allocating + // function which makes a distinction between allocating // a closure, and then the code for the closure. // // We represent a closure allocation by lifting the @@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator { // function marked as a closure is used to emit allocation // code for the closure's environment. // - // The "inner" function is should be used to generate the + // The "inner" function should be used to generate the // code for the closure. Function lifted_func; if (free_vars.size() == 0) { @@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator { CHECK(lifted_func.defined()); auto name = GenerateName(lifted_func); - auto global = this->module_->GetGlobalVar(name); + auto global = module_->GetGlobalVar(name); - lifted_.push_back({global, lifted_func}); + // Add the lifted function to the module. + module_->Add(global, lifted_func); if (free_vars.size() == 0) { return std::move(global); } else { - // If we need to allocate a closure - // we pass the variables in its environment - // here. + // If we need to allocate a closure, + // we pass the variables in its environment here. Array fvs; for (auto fv : free_vars) { fvs.push_back(fv); @@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator { } } - Function Lift(const Function& func) { - DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl; - return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); + Module Lift() { + // There is an ordering bug here. + auto glob_funcs = module_->functions; + for (auto pair : glob_funcs) { + auto func = pair.second; + DLOG(INFO) << "Lifting " << AsText(func, false); + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(pair.first, func, true); + } + return module_; } }; -/* The goal of this pass is to lift out any nested functions into top-level - * functions. - * - * We will lift the functions out into globals which take the set of the free vars - * and then return a function whcih has b - */ -Module LambdaLift(const Module& module) { - LambdaLifter lifter(module); - - tvm::Map updates; +} // namespace vm - // There is an ordering bug here. - for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, lifter.Lift(func)); - } +namespace transform { - for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) { - module->Add(i->first, i->second); - } +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::LambdaLifter(m).Lift(); + }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.LambdaLift") +.set_body_typed(LambdaLift); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 05eb43d6a653..782bb6a5980f 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing module pass : " + DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; Module updated_mod = mod; - Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. + std::vector > updates; for (const auto& it : mod->functions) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, updated_mod, pass_ctx); - new_mod->Add(it.first, updated_func); + updates.push_back({it.first, updated_func}); + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); } - return new_mod; + return updated_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode @@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) tvm::IRPrinter* p) { p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) + p->stream << "\tfallback device: " + << runtime::DeviceName(node->fallback_device) << "\n"; p->stream << "\trequired passes: [" << node->opt_level; From 92dfe376a6a7f0a31adc1deba37754d3282b533c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 10 Jun 2019 18:15:11 -0700 Subject: [PATCH 112/176] [Relay][Prelude] Use the Relay parser to define the Relay prelude (#3043) * Add ability to load Prelude from disk * Port over id * Define compose * Linting errors and style changes * Eliminate unnecessary parens * Rename identType to typeIdent (makes more sense) * Another unnecessary paren * Bump the version number for the text format * Ensure .rly (Relay text files) are permitted * Correct release number and simplify grammar rule * Correct load_prelude docstring * Corrections to _parser * Add Apache headers to prelude source file * Remove test_prelude (redundant) * Correct misleading error message * Add check that parser is enabled in Prelude * Commit pre-generated parser, ensure generated files are treated as binaries, and have parser tests always fire * Permit parser files and git attributes files * Exclude gitattributes and parser files from apache check * Another attempt at appeasing Apache audit checker * Corrections to rat-excludes * Apache should be truly appeased now * Ignore Relay parser files by name * Mark parser files as generated so they don't show up on Github * Add parsing helper function for tests * Mark parser files as not detectable --- python/tvm/relay/_parser.py | 42 +- python/tvm/relay/grammar/Relay.g4 | 19 +- python/tvm/relay/grammar/py2/.gitattributes | 3 + python/tvm/relay/grammar/py2/.gitignore | 1 - python/tvm/relay/grammar/py2/Relay.interp | 109 + python/tvm/relay/grammar/py2/Relay.tokens | 70 + .../tvm/relay/grammar/py2/RelayLexer.interp | 140 + python/tvm/relay/grammar/py2/RelayLexer.py | 209 ++ .../tvm/relay/grammar/py2/RelayLexer.tokens | 70 + python/tvm/relay/grammar/py2/RelayParser.py | 2311 +++++++++++++++++ python/tvm/relay/grammar/py2/RelayVisitor.py | 192 ++ python/tvm/relay/grammar/py3/.gitattributes | 3 + python/tvm/relay/grammar/py3/.gitignore | 1 - python/tvm/relay/grammar/py3/Relay.interp | 109 + python/tvm/relay/grammar/py3/Relay.tokens | 70 + .../tvm/relay/grammar/py3/RelayLexer.interp | 140 + python/tvm/relay/grammar/py3/RelayLexer.py | 203 ++ .../tvm/relay/grammar/py3/RelayLexer.tokens | 70 + python/tvm/relay/grammar/py3/RelayParser.py | 2307 ++++++++++++++++ python/tvm/relay/grammar/py3/RelayVisitor.py | 198 ++ python/tvm/relay/parser.py | 13 - python/tvm/relay/prelude.py | 51 +- python/tvm/relay/prelude.rly | 29 + tests/lint/check_file_type.py | 8 +- tests/lint/rat-excludes | 8 + tests/python/relay/test_ir_parser.py | 152 +- 26 files changed, 6380 insertions(+), 148 deletions(-) create mode 100644 python/tvm/relay/grammar/py2/.gitattributes delete mode 100644 python/tvm/relay/grammar/py2/.gitignore create mode 100644 python/tvm/relay/grammar/py2/Relay.interp create mode 100644 python/tvm/relay/grammar/py2/Relay.tokens create mode 100644 python/tvm/relay/grammar/py2/RelayLexer.interp create mode 100644 python/tvm/relay/grammar/py2/RelayLexer.py create mode 100644 python/tvm/relay/grammar/py2/RelayLexer.tokens create mode 100644 python/tvm/relay/grammar/py2/RelayParser.py create mode 100644 python/tvm/relay/grammar/py2/RelayVisitor.py create mode 100644 python/tvm/relay/grammar/py3/.gitattributes delete mode 100644 python/tvm/relay/grammar/py3/.gitignore create mode 100644 python/tvm/relay/grammar/py3/Relay.interp create mode 100644 python/tvm/relay/grammar/py3/Relay.tokens create mode 100644 python/tvm/relay/grammar/py3/RelayLexer.interp create mode 100644 python/tvm/relay/grammar/py3/RelayLexer.py create mode 100644 python/tvm/relay/grammar/py3/RelayLexer.tokens create mode 100644 python/tvm/relay/grammar/py3/RelayParser.py create mode 100644 python/tvm/relay/grammar/py3/RelayVisitor.py create mode 100644 python/tvm/relay/prelude.rly diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 62f0ffe15cba..303f694896a5 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -242,10 +242,12 @@ def visitProg(self, ctx): self.visit_list(ctx.defn()) return self.module - return self.visit(ctx.expr()) + if ctx.expr(): + return self.visit(ctx.expr()) - # Exprs + return self.module + # Exprs def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> op.Op return op.get(ctx.CNAME().getText()) @@ -368,14 +370,25 @@ def mk_func(self, ctx): self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() + type_params = ctx.typeParamSeq() + + if type_params is not None: + type_params = type_params.ident() + assert type_params + for ty_param in type_params: + name = ty_param.getText() + self.mk_typ(name, ty.Kind.Type) + var_list, attr_list = self.visit(ctx.argList()) ret_type = self.getType_(ctx.type_()) + body = self.visit(ctx.body()) + # NB(@jroesch): you must stay in the type parameter scope until + # after you exit the body, you can reference the type parameters + # of your parent scopes. type_params = list(self.exit_type_param_scope()) if type_params: _, type_params = zip(*type_params) - - body = self.visit(ctx.body()) self.exit_var_scope() attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None @@ -453,16 +466,23 @@ def visitIncompleteType(self, ctx): # type (RelayParser.IncompleteTypeContext) -> None: return None - def visitIdentType(self, ctx): - # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] - ident_type = ctx.CNAME().getText() + def visitTypeIdent(self, ctx): + # type: (RelayParser.TypeIdentContext) -> Union[ty.TensorType, str] + ''' + Handle type identifier. + ''' + type_ident = ctx.CNAME().getText() - # look through all type prefixes for a match + # Look through all type prefixes for a match for type_prefix in TYPE_PREFIXES: - if ident_type.startswith(type_prefix): - return ty.scalar_type(ident_type) + if type_ident.startswith(type_prefix): + return ty.scalar_type(type_ident) + + type_param = lookup(self.type_param_scopes, type_ident) + if type_param is not None: + return type_param - raise ParseError("Unknown builtin type: {}".format(ident_type)) + raise ParseError("Unknown builtin type: {}".format(type_ident)) # def visitCallType(self, ctx): # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 58546439e1ce..97b4ea24a8b2 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -19,7 +19,7 @@ grammar Relay; -SEMVER: 'v0.0.1' ; +SEMVER: 'v0.0.2' ; // Lexing // comments @@ -111,8 +111,8 @@ expr // | 'debug' # debug ; -func: 'fn' '(' argList ')' ('->' type_)? body ; -defn: 'def' ident '(' argList ')' ('->' type_)? body ; +func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ; +defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ; argList : varList @@ -132,15 +132,20 @@ attr: CNAME '=' expr ; // relations: 'where' relation (',' relation)* ; // relation: ident '(' (type_ (',' type_)*)? ')' ; +typeParamSeq + : '[' ']' + | '[' ident (',' ident)* ']' + ; + type_ : '(' ')' # tupleType | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType - | identType # identTypeType + | typeIdent # typeIdentType | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType // currently unused - // | identType '[' (type_ (',' type_)*)? ']' # callType - | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + // | typeIdent '[' (type_ (',' type_)*)? ']' # callType + | 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType | NAT # intType ; @@ -158,7 +163,7 @@ shape | NAT # intShape ; -identType: CNAME ; +typeIdent : CNAME ; // int8, int16, int32, int64 // uint8, uint16, uint32, uint64 // float16, float32, float64 diff --git a/python/tvm/relay/grammar/py2/.gitattributes b/python/tvm/relay/grammar/py2/.gitattributes new file mode 100644 index 000000000000..4adf65fa2f3c --- /dev/null +++ b/python/tvm/relay/grammar/py2/.gitattributes @@ -0,0 +1,3 @@ +Relay* binary +Relay* linguist-generated=true +Relay* linguist-detectable=false \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/.gitignore b/python/tvm/relay/grammar/py2/.gitignore deleted file mode 100644 index d677ff551940..000000000000 --- a/python/tvm/relay/grammar/py2/.gitignore +++ /dev/null @@ -1 +0,0 @@ -Relay* diff --git a/python/tvm/relay/grammar/py2/Relay.interp b/python/tvm/relay/grammar/py2/Relay.interp new file mode 100644 index 000000000000..c6893d096168 --- /dev/null +++ b/python/tvm/relay/grammar/py2/Relay.interp @@ -0,0 +1,109 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +opIdent +prog +expr +func +defn +argList +varList +var +attrList +attr +typeParamSeq +type_ +shapeSeq +shape +typeIdent +body +scalar +ident + + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 42, 332, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 3, 2, 3, 2, 3, 3, 3, 3, 7, 3, 43, 10, 3, 12, 3, 14, 3, 46, 11, 3, 3, 3, 5, 3, 49, 10, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 6, 4, 72, 10, 4, 13, 4, 14, 4, 73, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 82, 10, 4, 12, 4, 14, 4, 85, 11, 4, 5, 4, 87, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 100, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 110, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 128, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 150, 10, 4, 12, 4, 14, 4, 153, 11, 4, 5, 4, 155, 10, 4, 3, 4, 7, 4, 158, 10, 4, 12, 4, 14, 4, 161, 11, 4, 3, 5, 3, 5, 5, 5, 165, 10, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 5, 5, 172, 10, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 5, 6, 179, 10, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 5, 6, 186, 10, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 5, 7, 196, 10, 7, 3, 8, 3, 8, 3, 8, 7, 8, 201, 10, 8, 12, 8, 14, 8, 204, 11, 8, 5, 8, 206, 10, 8, 3, 9, 3, 9, 3, 9, 5, 9, 211, 10, 9, 3, 10, 3, 10, 3, 10, 7, 10, 216, 10, 10, 12, 10, 14, 10, 219, 11, 10, 5, 10, 221, 10, 10, 3, 11, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 7, 12, 233, 10, 12, 12, 12, 14, 12, 236, 11, 12, 3, 12, 3, 12, 5, 12, 240, 10, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 6, 13, 253, 10, 13, 13, 13, 14, 13, 254, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 269, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 7, 13, 275, 10, 13, 12, 13, 14, 13, 278, 11, 13, 5, 13, 280, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 287, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 6, 14, 300, 10, 14, 13, 14, 14, 14, 301, 3, 14, 3, 14, 5, 14, 306, 10, 14, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 5, 15, 313, 10, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 5, 18, 324, 10, 18, 3, 19, 3, 19, 3, 19, 3, 19, 5, 19, 330, 10, 19, 3, 19, 2, 3, 6, 20, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 2, 6, 3, 2, 25, 26, 3, 2, 27, 28, 3, 2, 29, 32, 3, 2, 33, 34, 2, 373, 2, 38, 3, 2, 2, 2, 4, 40, 3, 2, 2, 2, 6, 127, 3, 2, 2, 2, 8, 162, 3, 2, 2, 2, 10, 175, 3, 2, 2, 2, 12, 195, 3, 2, 2, 2, 14, 205, 3, 2, 2, 2, 16, 207, 3, 2, 2, 2, 18, 220, 3, 2, 2, 2, 20, 222, 3, 2, 2, 2, 22, 239, 3, 2, 2, 2, 24, 286, 3, 2, 2, 2, 26, 305, 3, 2, 2, 2, 28, 312, 3, 2, 2, 2, 30, 314, 3, 2, 2, 2, 32, 316, 3, 2, 2, 2, 34, 323, 3, 2, 2, 2, 36, 329, 3, 2, 2, 2, 38, 39, 7, 42, 2, 2, 39, 3, 3, 2, 2, 2, 40, 48, 7, 21, 2, 2, 41, 43, 5, 10, 6, 2, 42, 41, 3, 2, 2, 2, 43, 46, 3, 2, 2, 2, 44, 42, 3, 2, 2, 2, 44, 45, 3, 2, 2, 2, 45, 49, 3, 2, 2, 2, 46, 44, 3, 2, 2, 2, 47, 49, 5, 6, 4, 2, 48, 44, 3, 2, 2, 2, 48, 47, 3, 2, 2, 2, 49, 50, 3, 2, 2, 2, 50, 51, 7, 2, 2, 3, 51, 5, 3, 2, 2, 2, 52, 53, 8, 4, 1, 2, 53, 54, 7, 3, 2, 2, 54, 55, 5, 6, 4, 2, 55, 56, 7, 4, 2, 2, 56, 128, 3, 2, 2, 2, 57, 58, 7, 28, 2, 2, 58, 128, 5, 6, 4, 19, 59, 128, 5, 8, 5, 2, 60, 61, 7, 3, 2, 2, 61, 128, 7, 4, 2, 2, 62, 63, 7, 3, 2, 2, 63, 64, 5, 6, 4, 2, 64, 65, 7, 5, 2, 2, 65, 66, 7, 4, 2, 2, 66, 128, 3, 2, 2, 2, 67, 68, 7, 3, 2, 2, 68, 71, 5, 6, 4, 2, 69, 70, 7, 5, 2, 2, 70, 72, 5, 6, 4, 2, 71, 69, 3, 2, 2, 2, 72, 73, 3, 2, 2, 2, 73, 71, 3, 2, 2, 2, 73, 74, 3, 2, 2, 2, 74, 75, 3, 2, 2, 2, 75, 76, 7, 4, 2, 2, 76, 128, 3, 2, 2, 2, 77, 86, 7, 6, 2, 2, 78, 83, 5, 6, 4, 2, 79, 80, 7, 5, 2, 2, 80, 82, 5, 6, 4, 2, 81, 79, 3, 2, 2, 2, 82, 85, 3, 2, 2, 2, 83, 81, 3, 2, 2, 2, 83, 84, 3, 2, 2, 2, 84, 87, 3, 2, 2, 2, 85, 83, 3, 2, 2, 2, 86, 78, 3, 2, 2, 2, 86, 87, 3, 2, 2, 2, 87, 88, 3, 2, 2, 2, 88, 128, 7, 7, 2, 2, 89, 90, 7, 8, 2, 2, 90, 91, 7, 3, 2, 2, 91, 92, 5, 6, 4, 2, 92, 93, 7, 4, 2, 2, 93, 94, 5, 32, 17, 2, 94, 95, 7, 9, 2, 2, 95, 96, 5, 32, 17, 2, 96, 128, 3, 2, 2, 2, 97, 99, 7, 10, 2, 2, 98, 100, 7, 38, 2, 2, 99, 98, 3, 2, 2, 2, 99, 100, 3, 2, 2, 2, 100, 101, 3, 2, 2, 2, 101, 102, 5, 16, 9, 2, 102, 103, 7, 11, 2, 2, 103, 104, 5, 6, 4, 2, 104, 105, 7, 12, 2, 2, 105, 106, 5, 6, 4, 8, 106, 128, 3, 2, 2, 2, 107, 109, 7, 10, 2, 2, 108, 110, 7, 38, 2, 2, 109, 108, 3, 2, 2, 2, 109, 110, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 112, 5, 16, 9, 2, 112, 113, 7, 11, 2, 2, 113, 114, 7, 13, 2, 2, 114, 115, 5, 6, 4, 2, 115, 116, 7, 14, 2, 2, 116, 117, 7, 12, 2, 2, 117, 118, 5, 6, 4, 7, 118, 128, 3, 2, 2, 2, 119, 120, 5, 36, 19, 2, 120, 121, 7, 11, 2, 2, 121, 122, 5, 6, 4, 2, 122, 123, 7, 12, 2, 2, 123, 124, 5, 6, 4, 5, 124, 128, 3, 2, 2, 2, 125, 128, 5, 36, 19, 2, 126, 128, 5, 34, 18, 2, 127, 52, 3, 2, 2, 2, 127, 57, 3, 2, 2, 2, 127, 59, 3, 2, 2, 2, 127, 60, 3, 2, 2, 2, 127, 62, 3, 2, 2, 2, 127, 67, 3, 2, 2, 2, 127, 77, 3, 2, 2, 2, 127, 89, 3, 2, 2, 2, 127, 97, 3, 2, 2, 2, 127, 107, 3, 2, 2, 2, 127, 119, 3, 2, 2, 2, 127, 125, 3, 2, 2, 2, 127, 126, 3, 2, 2, 2, 128, 159, 3, 2, 2, 2, 129, 130, 12, 18, 2, 2, 130, 131, 9, 2, 2, 2, 131, 158, 5, 6, 4, 19, 132, 133, 12, 17, 2, 2, 133, 134, 9, 3, 2, 2, 134, 158, 5, 6, 4, 18, 135, 136, 12, 16, 2, 2, 136, 137, 9, 4, 2, 2, 137, 158, 5, 6, 4, 17, 138, 139, 12, 15, 2, 2, 139, 140, 9, 5, 2, 2, 140, 158, 5, 6, 4, 16, 141, 142, 12, 6, 2, 2, 142, 143, 7, 12, 2, 2, 143, 158, 5, 6, 4, 7, 144, 145, 12, 20, 2, 2, 145, 154, 7, 3, 2, 2, 146, 151, 5, 6, 4, 2, 147, 148, 7, 5, 2, 2, 148, 150, 5, 6, 4, 2, 149, 147, 3, 2, 2, 2, 150, 153, 3, 2, 2, 2, 151, 149, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 155, 3, 2, 2, 2, 153, 151, 3, 2, 2, 2, 154, 146, 3, 2, 2, 2, 154, 155, 3, 2, 2, 2, 155, 156, 3, 2, 2, 2, 156, 158, 7, 4, 2, 2, 157, 129, 3, 2, 2, 2, 157, 132, 3, 2, 2, 2, 157, 135, 3, 2, 2, 2, 157, 138, 3, 2, 2, 2, 157, 141, 3, 2, 2, 2, 157, 144, 3, 2, 2, 2, 158, 161, 3, 2, 2, 2, 159, 157, 3, 2, 2, 2, 159, 160, 3, 2, 2, 2, 160, 7, 3, 2, 2, 2, 161, 159, 3, 2, 2, 2, 162, 164, 7, 15, 2, 2, 163, 165, 5, 22, 12, 2, 164, 163, 3, 2, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 3, 2, 2, 2, 166, 167, 7, 3, 2, 2, 167, 168, 5, 12, 7, 2, 168, 171, 7, 4, 2, 2, 169, 170, 7, 16, 2, 2, 170, 172, 5, 24, 13, 2, 171, 169, 3, 2, 2, 2, 171, 172, 3, 2, 2, 2, 172, 173, 3, 2, 2, 2, 173, 174, 5, 32, 17, 2, 174, 9, 3, 2, 2, 2, 175, 176, 7, 17, 2, 2, 176, 178, 5, 36, 19, 2, 177, 179, 5, 22, 12, 2, 178, 177, 3, 2, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 3, 2, 2, 2, 180, 181, 7, 3, 2, 2, 181, 182, 5, 12, 7, 2, 182, 185, 7, 4, 2, 2, 183, 184, 7, 16, 2, 2, 184, 186, 5, 24, 13, 2, 185, 183, 3, 2, 2, 2, 185, 186, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 188, 5, 32, 17, 2, 188, 11, 3, 2, 2, 2, 189, 196, 5, 14, 8, 2, 190, 196, 5, 18, 10, 2, 191, 192, 5, 14, 8, 2, 192, 193, 7, 5, 2, 2, 193, 194, 5, 18, 10, 2, 194, 196, 3, 2, 2, 2, 195, 189, 3, 2, 2, 2, 195, 190, 3, 2, 2, 2, 195, 191, 3, 2, 2, 2, 196, 13, 3, 2, 2, 2, 197, 202, 5, 16, 9, 2, 198, 199, 7, 5, 2, 2, 199, 201, 5, 16, 9, 2, 200, 198, 3, 2, 2, 2, 201, 204, 3, 2, 2, 2, 202, 200, 3, 2, 2, 2, 202, 203, 3, 2, 2, 2, 203, 206, 3, 2, 2, 2, 204, 202, 3, 2, 2, 2, 205, 197, 3, 2, 2, 2, 205, 206, 3, 2, 2, 2, 206, 15, 3, 2, 2, 2, 207, 210, 5, 36, 19, 2, 208, 209, 7, 18, 2, 2, 209, 211, 5, 24, 13, 2, 210, 208, 3, 2, 2, 2, 210, 211, 3, 2, 2, 2, 211, 17, 3, 2, 2, 2, 212, 217, 5, 20, 11, 2, 213, 214, 7, 5, 2, 2, 214, 216, 5, 20, 11, 2, 215, 213, 3, 2, 2, 2, 216, 219, 3, 2, 2, 2, 217, 215, 3, 2, 2, 2, 217, 218, 3, 2, 2, 2, 218, 221, 3, 2, 2, 2, 219, 217, 3, 2, 2, 2, 220, 212, 3, 2, 2, 2, 220, 221, 3, 2, 2, 2, 221, 19, 3, 2, 2, 2, 222, 223, 7, 42, 2, 2, 223, 224, 7, 11, 2, 2, 224, 225, 5, 6, 4, 2, 225, 21, 3, 2, 2, 2, 226, 227, 7, 6, 2, 2, 227, 240, 7, 7, 2, 2, 228, 229, 7, 6, 2, 2, 229, 234, 5, 36, 19, 2, 230, 231, 7, 5, 2, 2, 231, 233, 5, 36, 19, 2, 232, 230, 3, 2, 2, 2, 233, 236, 3, 2, 2, 2, 234, 232, 3, 2, 2, 2, 234, 235, 3, 2, 2, 2, 235, 237, 3, 2, 2, 2, 236, 234, 3, 2, 2, 2, 237, 238, 7, 7, 2, 2, 238, 240, 3, 2, 2, 2, 239, 226, 3, 2, 2, 2, 239, 228, 3, 2, 2, 2, 240, 23, 3, 2, 2, 2, 241, 242, 7, 3, 2, 2, 242, 287, 7, 4, 2, 2, 243, 244, 7, 3, 2, 2, 244, 245, 5, 24, 13, 2, 245, 246, 7, 5, 2, 2, 246, 247, 7, 4, 2, 2, 247, 287, 3, 2, 2, 2, 248, 249, 7, 3, 2, 2, 249, 252, 5, 24, 13, 2, 250, 251, 7, 5, 2, 2, 251, 253, 5, 24, 13, 2, 252, 250, 3, 2, 2, 2, 253, 254, 3, 2, 2, 2, 254, 252, 3, 2, 2, 2, 254, 255, 3, 2, 2, 2, 255, 256, 3, 2, 2, 2, 256, 257, 7, 4, 2, 2, 257, 287, 3, 2, 2, 2, 258, 287, 5, 30, 16, 2, 259, 260, 7, 19, 2, 2, 260, 261, 7, 6, 2, 2, 261, 262, 5, 26, 14, 2, 262, 263, 7, 5, 2, 2, 263, 264, 5, 24, 13, 2, 264, 265, 7, 7, 2, 2, 265, 287, 3, 2, 2, 2, 266, 268, 7, 15, 2, 2, 267, 269, 5, 22, 12, 2, 268, 267, 3, 2, 2, 2, 268, 269, 3, 2, 2, 2, 269, 270, 3, 2, 2, 2, 270, 279, 7, 3, 2, 2, 271, 276, 5, 24, 13, 2, 272, 273, 7, 5, 2, 2, 273, 275, 5, 24, 13, 2, 274, 272, 3, 2, 2, 2, 275, 278, 3, 2, 2, 2, 276, 274, 3, 2, 2, 2, 276, 277, 3, 2, 2, 2, 277, 280, 3, 2, 2, 2, 278, 276, 3, 2, 2, 2, 279, 271, 3, 2, 2, 2, 279, 280, 3, 2, 2, 2, 280, 281, 3, 2, 2, 2, 281, 282, 7, 4, 2, 2, 282, 283, 7, 16, 2, 2, 283, 287, 5, 24, 13, 2, 284, 287, 7, 20, 2, 2, 285, 287, 7, 41, 2, 2, 286, 241, 3, 2, 2, 2, 286, 243, 3, 2, 2, 2, 286, 248, 3, 2, 2, 2, 286, 258, 3, 2, 2, 2, 286, 259, 3, 2, 2, 2, 286, 266, 3, 2, 2, 2, 286, 284, 3, 2, 2, 2, 286, 285, 3, 2, 2, 2, 287, 25, 3, 2, 2, 2, 288, 289, 7, 3, 2, 2, 289, 306, 7, 4, 2, 2, 290, 291, 7, 3, 2, 2, 291, 292, 5, 28, 15, 2, 292, 293, 7, 5, 2, 2, 293, 294, 7, 4, 2, 2, 294, 306, 3, 2, 2, 2, 295, 296, 7, 3, 2, 2, 296, 299, 5, 28, 15, 2, 297, 298, 7, 5, 2, 2, 298, 300, 5, 28, 15, 2, 299, 297, 3, 2, 2, 2, 300, 301, 3, 2, 2, 2, 301, 299, 3, 2, 2, 2, 301, 302, 3, 2, 2, 2, 302, 303, 3, 2, 2, 2, 303, 304, 7, 4, 2, 2, 304, 306, 3, 2, 2, 2, 305, 288, 3, 2, 2, 2, 305, 290, 3, 2, 2, 2, 305, 295, 3, 2, 2, 2, 306, 27, 3, 2, 2, 2, 307, 308, 7, 3, 2, 2, 308, 309, 5, 28, 15, 2, 309, 310, 7, 4, 2, 2, 310, 313, 3, 2, 2, 2, 311, 313, 7, 41, 2, 2, 312, 307, 3, 2, 2, 2, 312, 311, 3, 2, 2, 2, 313, 29, 3, 2, 2, 2, 314, 315, 7, 42, 2, 2, 315, 31, 3, 2, 2, 2, 316, 317, 7, 13, 2, 2, 317, 318, 5, 6, 4, 2, 318, 319, 7, 14, 2, 2, 319, 33, 3, 2, 2, 2, 320, 324, 7, 40, 2, 2, 321, 324, 7, 41, 2, 2, 322, 324, 7, 39, 2, 2, 323, 320, 3, 2, 2, 2, 323, 321, 3, 2, 2, 2, 323, 322, 3, 2, 2, 2, 324, 35, 3, 2, 2, 2, 325, 330, 5, 2, 2, 2, 326, 330, 7, 35, 2, 2, 327, 330, 7, 36, 2, 2, 328, 330, 7, 37, 2, 2, 329, 325, 3, 2, 2, 2, 329, 326, 3, 2, 2, 2, 329, 327, 3, 2, 2, 2, 329, 328, 3, 2, 2, 2, 330, 37, 3, 2, 2, 2, 36, 44, 48, 73, 83, 86, 99, 109, 127, 151, 154, 157, 159, 164, 171, 178, 185, 195, 202, 205, 210, 217, 220, 234, 239, 254, 268, 276, 279, 286, 301, 305, 312, 323, 329] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/Relay.tokens b/python/tvm/relay/grammar/py2/Relay.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py2/Relay.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py2/RelayLexer.interp b/python/tvm/relay/grammar/py2/RelayLexer.interp new file mode 100644 index 000000000000..092b3589ab70 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.interp @@ -0,0 +1,140 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +T__0 +T__1 +T__2 +T__3 +T__4 +T__5 +T__6 +T__7 +T__8 +T__9 +T__10 +T__11 +T__12 +T__13 +T__14 +T__15 +T__16 +T__17 +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +EXP +CNAME +LETTER +DIGIT + +channel names: +DEFAULT_TOKEN_CHANNEL +HIDDEN + +mode names: +DEFAULT_MODE + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 42, 267, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 6, 21, 149, 10, 21, 13, 21, 14, 21, 150, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 22, 7, 22, 159, 10, 22, 12, 22, 14, 22, 162, 11, 22, 3, 22, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 23, 3, 23, 7, 23, 172, 10, 23, 12, 23, 14, 23, 175, 11, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 34, 3, 34, 3, 34, 3, 35, 3, 35, 3, 35, 3, 36, 3, 36, 3, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 228, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 234, 10, 39, 3, 39, 3, 39, 3, 39, 5, 39, 239, 10, 39, 3, 40, 6, 40, 242, 10, 40, 13, 40, 14, 40, 243, 3, 41, 3, 41, 5, 41, 248, 10, 41, 3, 41, 3, 41, 3, 42, 3, 42, 5, 42, 254, 10, 42, 3, 42, 3, 42, 3, 42, 7, 42, 259, 10, 42, 12, 42, 14, 42, 262, 11, 42, 3, 43, 3, 43, 3, 44, 3, 44, 4, 160, 173, 2, 45, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 2, 83, 42, 85, 2, 87, 2, 3, 2, 7, 5, 2, 11, 12, 15, 15, 34, 34, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 67, 92, 99, 124, 3, 2, 50, 59, 2, 275, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 3, 89, 3, 2, 2, 2, 5, 91, 3, 2, 2, 2, 7, 93, 3, 2, 2, 2, 9, 95, 3, 2, 2, 2, 11, 97, 3, 2, 2, 2, 13, 99, 3, 2, 2, 2, 15, 102, 3, 2, 2, 2, 17, 107, 3, 2, 2, 2, 19, 111, 3, 2, 2, 2, 21, 113, 3, 2, 2, 2, 23, 115, 3, 2, 2, 2, 25, 117, 3, 2, 2, 2, 27, 119, 3, 2, 2, 2, 29, 122, 3, 2, 2, 2, 31, 125, 3, 2, 2, 2, 33, 129, 3, 2, 2, 2, 35, 131, 3, 2, 2, 2, 37, 138, 3, 2, 2, 2, 39, 140, 3, 2, 2, 2, 41, 148, 3, 2, 2, 2, 43, 154, 3, 2, 2, 2, 45, 167, 3, 2, 2, 2, 47, 181, 3, 2, 2, 2, 49, 183, 3, 2, 2, 2, 51, 185, 3, 2, 2, 2, 53, 187, 3, 2, 2, 2, 55, 189, 3, 2, 2, 2, 57, 191, 3, 2, 2, 2, 59, 193, 3, 2, 2, 2, 61, 196, 3, 2, 2, 2, 63, 199, 3, 2, 2, 2, 65, 202, 3, 2, 2, 2, 67, 205, 3, 2, 2, 2, 69, 208, 3, 2, 2, 2, 71, 211, 3, 2, 2, 2, 73, 214, 3, 2, 2, 2, 75, 227, 3, 2, 2, 2, 77, 238, 3, 2, 2, 2, 79, 241, 3, 2, 2, 2, 81, 245, 3, 2, 2, 2, 83, 253, 3, 2, 2, 2, 85, 263, 3, 2, 2, 2, 87, 265, 3, 2, 2, 2, 89, 90, 7, 42, 2, 2, 90, 4, 3, 2, 2, 2, 91, 92, 7, 43, 2, 2, 92, 6, 3, 2, 2, 2, 93, 94, 7, 46, 2, 2, 94, 8, 3, 2, 2, 2, 95, 96, 7, 93, 2, 2, 96, 10, 3, 2, 2, 2, 97, 98, 7, 95, 2, 2, 98, 12, 3, 2, 2, 2, 99, 100, 7, 107, 2, 2, 100, 101, 7, 104, 2, 2, 101, 14, 3, 2, 2, 2, 102, 103, 7, 103, 2, 2, 103, 104, 7, 110, 2, 2, 104, 105, 7, 117, 2, 2, 105, 106, 7, 103, 2, 2, 106, 16, 3, 2, 2, 2, 107, 108, 7, 110, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 18, 3, 2, 2, 2, 111, 112, 7, 63, 2, 2, 112, 20, 3, 2, 2, 2, 113, 114, 7, 61, 2, 2, 114, 22, 3, 2, 2, 2, 115, 116, 7, 125, 2, 2, 116, 24, 3, 2, 2, 2, 117, 118, 7, 127, 2, 2, 118, 26, 3, 2, 2, 2, 119, 120, 7, 104, 2, 2, 120, 121, 7, 112, 2, 2, 121, 28, 3, 2, 2, 2, 122, 123, 7, 47, 2, 2, 123, 124, 7, 64, 2, 2, 124, 30, 3, 2, 2, 2, 125, 126, 7, 102, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 104, 2, 2, 128, 32, 3, 2, 2, 2, 129, 130, 7, 60, 2, 2, 130, 34, 3, 2, 2, 2, 131, 132, 7, 86, 2, 2, 132, 133, 7, 103, 2, 2, 133, 134, 7, 112, 2, 2, 134, 135, 7, 117, 2, 2, 135, 136, 7, 113, 2, 2, 136, 137, 7, 116, 2, 2, 137, 36, 3, 2, 2, 2, 138, 139, 7, 97, 2, 2, 139, 38, 3, 2, 2, 2, 140, 141, 7, 120, 2, 2, 141, 142, 7, 50, 2, 2, 142, 143, 7, 48, 2, 2, 143, 144, 7, 50, 2, 2, 144, 145, 7, 48, 2, 2, 145, 146, 7, 52, 2, 2, 146, 40, 3, 2, 2, 2, 147, 149, 9, 2, 2, 2, 148, 147, 3, 2, 2, 2, 149, 150, 3, 2, 2, 2, 150, 148, 3, 2, 2, 2, 150, 151, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 153, 8, 21, 2, 2, 153, 42, 3, 2, 2, 2, 154, 155, 7, 49, 2, 2, 155, 156, 7, 49, 2, 2, 156, 160, 3, 2, 2, 2, 157, 159, 11, 2, 2, 2, 158, 157, 3, 2, 2, 2, 159, 162, 3, 2, 2, 2, 160, 161, 3, 2, 2, 2, 160, 158, 3, 2, 2, 2, 161, 163, 3, 2, 2, 2, 162, 160, 3, 2, 2, 2, 163, 164, 7, 12, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 8, 22, 2, 2, 166, 44, 3, 2, 2, 2, 167, 168, 7, 49, 2, 2, 168, 169, 7, 44, 2, 2, 169, 173, 3, 2, 2, 2, 170, 172, 11, 2, 2, 2, 171, 170, 3, 2, 2, 2, 172, 175, 3, 2, 2, 2, 173, 174, 3, 2, 2, 2, 173, 171, 3, 2, 2, 2, 174, 176, 3, 2, 2, 2, 175, 173, 3, 2, 2, 2, 176, 177, 7, 44, 2, 2, 177, 178, 7, 49, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 8, 23, 2, 2, 180, 46, 3, 2, 2, 2, 181, 182, 7, 44, 2, 2, 182, 48, 3, 2, 2, 2, 183, 184, 7, 49, 2, 2, 184, 50, 3, 2, 2, 2, 185, 186, 7, 45, 2, 2, 186, 52, 3, 2, 2, 2, 187, 188, 7, 47, 2, 2, 188, 54, 3, 2, 2, 2, 189, 190, 7, 62, 2, 2, 190, 56, 3, 2, 2, 2, 191, 192, 7, 64, 2, 2, 192, 58, 3, 2, 2, 2, 193, 194, 7, 62, 2, 2, 194, 195, 7, 63, 2, 2, 195, 60, 3, 2, 2, 2, 196, 197, 7, 64, 2, 2, 197, 198, 7, 63, 2, 2, 198, 62, 3, 2, 2, 2, 199, 200, 7, 63, 2, 2, 200, 201, 7, 63, 2, 2, 201, 64, 3, 2, 2, 2, 202, 203, 7, 35, 2, 2, 203, 204, 7, 63, 2, 2, 204, 66, 3, 2, 2, 2, 205, 206, 7, 66, 2, 2, 206, 207, 5, 83, 42, 2, 207, 68, 3, 2, 2, 2, 208, 209, 7, 39, 2, 2, 209, 210, 5, 83, 42, 2, 210, 70, 3, 2, 2, 2, 211, 212, 7, 39, 2, 2, 212, 213, 5, 79, 40, 2, 213, 72, 3, 2, 2, 2, 214, 215, 7, 111, 2, 2, 215, 216, 7, 119, 2, 2, 216, 217, 7, 118, 2, 2, 217, 74, 3, 2, 2, 2, 218, 219, 7, 86, 2, 2, 219, 220, 7, 116, 2, 2, 220, 221, 7, 119, 2, 2, 221, 228, 7, 103, 2, 2, 222, 223, 7, 72, 2, 2, 223, 224, 7, 99, 2, 2, 224, 225, 7, 110, 2, 2, 225, 226, 7, 117, 2, 2, 226, 228, 7, 103, 2, 2, 227, 218, 3, 2, 2, 2, 227, 222, 3, 2, 2, 2, 228, 76, 3, 2, 2, 2, 229, 230, 5, 79, 40, 2, 230, 231, 7, 48, 2, 2, 231, 233, 5, 79, 40, 2, 232, 234, 5, 81, 41, 2, 233, 232, 3, 2, 2, 2, 233, 234, 3, 2, 2, 2, 234, 239, 3, 2, 2, 2, 235, 236, 5, 79, 40, 2, 236, 237, 5, 81, 41, 2, 237, 239, 3, 2, 2, 2, 238, 229, 3, 2, 2, 2, 238, 235, 3, 2, 2, 2, 239, 78, 3, 2, 2, 2, 240, 242, 5, 87, 44, 2, 241, 240, 3, 2, 2, 2, 242, 243, 3, 2, 2, 2, 243, 241, 3, 2, 2, 2, 243, 244, 3, 2, 2, 2, 244, 80, 3, 2, 2, 2, 245, 247, 9, 3, 2, 2, 246, 248, 9, 4, 2, 2, 247, 246, 3, 2, 2, 2, 247, 248, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 250, 5, 79, 40, 2, 250, 82, 3, 2, 2, 2, 251, 254, 7, 97, 2, 2, 252, 254, 5, 85, 43, 2, 253, 251, 3, 2, 2, 2, 253, 252, 3, 2, 2, 2, 254, 260, 3, 2, 2, 2, 255, 259, 7, 97, 2, 2, 256, 259, 5, 85, 43, 2, 257, 259, 5, 87, 44, 2, 258, 255, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 258, 257, 3, 2, 2, 2, 259, 262, 3, 2, 2, 2, 260, 258, 3, 2, 2, 2, 260, 261, 3, 2, 2, 2, 261, 84, 3, 2, 2, 2, 262, 260, 3, 2, 2, 2, 263, 264, 9, 5, 2, 2, 264, 86, 3, 2, 2, 2, 265, 266, 9, 6, 2, 2, 266, 88, 3, 2, 2, 2, 14, 2, 150, 160, 173, 227, 233, 238, 243, 247, 253, 258, 260, 3, 8, 2, 2] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/RelayLexer.py b/python/tvm/relay/grammar/py2/RelayLexer.py new file mode 100644 index 000000000000..be87421c2da6 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.py @@ -0,0 +1,209 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from __future__ import print_function +from antlr4 import * +from io import StringIO +import sys + + + +def serializedATN(): + with StringIO() as buf: + buf.write(u"\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2") + buf.write(u"*\u010b\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4") + buf.write(u"\7\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r") + buf.write(u"\t\r\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22") + buf.write(u"\4\23\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4") + buf.write(u"\30\t\30\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35") + buf.write(u"\t\35\4\36\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4") + buf.write(u"$\t$\4%\t%\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t") + buf.write(u",\3\2\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7") + buf.write(u"\3\b\3\b\3\b\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13") + buf.write(u"\3\f\3\f\3\r\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3") + buf.write(u"\20\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22") + buf.write(u"\3\22\3\23\3\23\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3") + buf.write(u"\25\6\25\u0095\n\25\r\25\16\25\u0096\3\25\3\25\3\26\3") + buf.write(u"\26\3\26\3\26\7\26\u009f\n\26\f\26\16\26\u00a2\13\26") + buf.write(u"\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\7\27\u00ac\n") + buf.write(u"\27\f\27\16\27\u00af\13\27\3\27\3\27\3\27\3\27\3\27\3") + buf.write(u"\30\3\30\3\31\3\31\3\32\3\32\3\33\3\33\3\34\3\34\3\35") + buf.write(u"\3\35\3\36\3\36\3\36\3\37\3\37\3\37\3 \3 \3 \3!\3!\3") + buf.write(u"!\3\"\3\"\3\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&") + buf.write(u"\3&\3&\3&\3&\3&\3&\5&\u00e4\n&\3\'\3\'\3\'\3\'\5\'\u00ea") + buf.write(u"\n\'\3\'\3\'\3\'\5\'\u00ef\n\'\3(\6(\u00f2\n(\r(\16(") + buf.write(u"\u00f3\3)\3)\5)\u00f8\n)\3)\3)\3*\3*\5*\u00fe\n*\3*\3") + buf.write(u"*\3*\7*\u0103\n*\f*\16*\u0106\13*\3+\3+\3,\3,\4\u00a0") + buf.write(u"\u00ad\2-\3\3\5\4\7\5\t\6\13\7\r\b\17\t\21\n\23\13\25") + buf.write(u"\f\27\r\31\16\33\17\35\20\37\21!\22#\23%\24\'\25)\26") + buf.write(u"+\27-\30/\31\61\32\63\33\65\34\67\359\36;\37= ?!A\"C") + buf.write(u"#E$G%I&K\'M(O)Q\2S*U\2W\2\3\2\7\5\2\13\f\17\17\"\"\4") + buf.write(u"\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0113\2\3\3\2\2\2\2") + buf.write(u"\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") + buf.write(u"\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3") + buf.write(u"\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3") + buf.write(u"\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2") + buf.write(u"\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2") + buf.write(u"\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2") + buf.write(u"\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3") + buf.write(u"\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2") + buf.write(u"K\3\2\2\2\2M\3\2\2\2\2O\3\2\2\2\2S\3\2\2\2\3Y\3\2\2\2") + buf.write(u"\5[\3\2\2\2\7]\3\2\2\2\t_\3\2\2\2\13a\3\2\2\2\rc\3\2") + buf.write(u"\2\2\17f\3\2\2\2\21k\3\2\2\2\23o\3\2\2\2\25q\3\2\2\2") + buf.write(u"\27s\3\2\2\2\31u\3\2\2\2\33w\3\2\2\2\35z\3\2\2\2\37}") + buf.write(u"\3\2\2\2!\u0081\3\2\2\2#\u0083\3\2\2\2%\u008a\3\2\2\2") + buf.write(u"\'\u008c\3\2\2\2)\u0094\3\2\2\2+\u009a\3\2\2\2-\u00a7") + buf.write(u"\3\2\2\2/\u00b5\3\2\2\2\61\u00b7\3\2\2\2\63\u00b9\3\2") + buf.write(u"\2\2\65\u00bb\3\2\2\2\67\u00bd\3\2\2\29\u00bf\3\2\2\2") + buf.write(u";\u00c1\3\2\2\2=\u00c4\3\2\2\2?\u00c7\3\2\2\2A\u00ca") + buf.write(u"\3\2\2\2C\u00cd\3\2\2\2E\u00d0\3\2\2\2G\u00d3\3\2\2\2") + buf.write(u"I\u00d6\3\2\2\2K\u00e3\3\2\2\2M\u00ee\3\2\2\2O\u00f1") + buf.write(u"\3\2\2\2Q\u00f5\3\2\2\2S\u00fd\3\2\2\2U\u0107\3\2\2\2") + buf.write(u"W\u0109\3\2\2\2YZ\7*\2\2Z\4\3\2\2\2[\\\7+\2\2\\\6\3\2") + buf.write(u"\2\2]^\7.\2\2^\b\3\2\2\2_`\7]\2\2`\n\3\2\2\2ab\7_\2\2") + buf.write(u"b\f\3\2\2\2cd\7k\2\2de\7h\2\2e\16\3\2\2\2fg\7g\2\2gh") + buf.write(u"\7n\2\2hi\7u\2\2ij\7g\2\2j\20\3\2\2\2kl\7n\2\2lm\7g\2") + buf.write(u"\2mn\7v\2\2n\22\3\2\2\2op\7?\2\2p\24\3\2\2\2qr\7=\2\2") + buf.write(u"r\26\3\2\2\2st\7}\2\2t\30\3\2\2\2uv\7\177\2\2v\32\3\2") + buf.write(u"\2\2wx\7h\2\2xy\7p\2\2y\34\3\2\2\2z{\7/\2\2{|\7@\2\2") + buf.write(u"|\36\3\2\2\2}~\7f\2\2~\177\7g\2\2\177\u0080\7h\2\2\u0080") + buf.write(u" \3\2\2\2\u0081\u0082\7<\2\2\u0082\"\3\2\2\2\u0083\u0084") + buf.write(u"\7V\2\2\u0084\u0085\7g\2\2\u0085\u0086\7p\2\2\u0086\u0087") + buf.write(u"\7u\2\2\u0087\u0088\7q\2\2\u0088\u0089\7t\2\2\u0089$") + buf.write(u"\3\2\2\2\u008a\u008b\7a\2\2\u008b&\3\2\2\2\u008c\u008d") + buf.write(u"\7x\2\2\u008d\u008e\7\62\2\2\u008e\u008f\7\60\2\2\u008f") + buf.write(u"\u0090\7\62\2\2\u0090\u0091\7\60\2\2\u0091\u0092\7\64") + buf.write(u"\2\2\u0092(\3\2\2\2\u0093\u0095\t\2\2\2\u0094\u0093\3") + buf.write(u"\2\2\2\u0095\u0096\3\2\2\2\u0096\u0094\3\2\2\2\u0096") + buf.write(u"\u0097\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u0099\b\25\2") + buf.write(u"\2\u0099*\3\2\2\2\u009a\u009b\7\61\2\2\u009b\u009c\7") + buf.write(u"\61\2\2\u009c\u00a0\3\2\2\2\u009d\u009f\13\2\2\2\u009e") + buf.write(u"\u009d\3\2\2\2\u009f\u00a2\3\2\2\2\u00a0\u00a1\3\2\2") + buf.write(u"\2\u00a0\u009e\3\2\2\2\u00a1\u00a3\3\2\2\2\u00a2\u00a0") + buf.write(u"\3\2\2\2\u00a3\u00a4\7\f\2\2\u00a4\u00a5\3\2\2\2\u00a5") + buf.write(u"\u00a6\b\26\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7\61\2\2\u00a8") + buf.write(u"\u00a9\7,\2\2\u00a9\u00ad\3\2\2\2\u00aa\u00ac\13\2\2") + buf.write(u"\2\u00ab\u00aa\3\2\2\2\u00ac\u00af\3\2\2\2\u00ad\u00ae") + buf.write(u"\3\2\2\2\u00ad\u00ab\3\2\2\2\u00ae\u00b0\3\2\2\2\u00af") + buf.write(u"\u00ad\3\2\2\2\u00b0\u00b1\7,\2\2\u00b1\u00b2\7\61\2") + buf.write(u"\2\u00b2\u00b3\3\2\2\2\u00b3\u00b4\b\27\2\2\u00b4.\3") + buf.write(u"\2\2\2\u00b5\u00b6\7,\2\2\u00b6\60\3\2\2\2\u00b7\u00b8") + buf.write(u"\7\61\2\2\u00b8\62\3\2\2\2\u00b9\u00ba\7-\2\2\u00ba\64") + buf.write(u"\3\2\2\2\u00bb\u00bc\7/\2\2\u00bc\66\3\2\2\2\u00bd\u00be") + buf.write(u"\7>\2\2\u00be8\3\2\2\2\u00bf\u00c0\7@\2\2\u00c0:\3\2") + buf.write(u"\2\2\u00c1\u00c2\7>\2\2\u00c2\u00c3\7?\2\2\u00c3<\3\2") + buf.write(u"\2\2\u00c4\u00c5\7@\2\2\u00c5\u00c6\7?\2\2\u00c6>\3\2") + buf.write(u"\2\2\u00c7\u00c8\7?\2\2\u00c8\u00c9\7?\2\2\u00c9@\3\2") + buf.write(u"\2\2\u00ca\u00cb\7#\2\2\u00cb\u00cc\7?\2\2\u00ccB\3\2") + buf.write(u"\2\2\u00cd\u00ce\7B\2\2\u00ce\u00cf\5S*\2\u00cfD\3\2") + buf.write(u"\2\2\u00d0\u00d1\7\'\2\2\u00d1\u00d2\5S*\2\u00d2F\3\2") + buf.write(u"\2\2\u00d3\u00d4\7\'\2\2\u00d4\u00d5\5O(\2\u00d5H\3\2") + buf.write(u"\2\2\u00d6\u00d7\7o\2\2\u00d7\u00d8\7w\2\2\u00d8\u00d9") + buf.write(u"\7v\2\2\u00d9J\3\2\2\2\u00da\u00db\7V\2\2\u00db\u00dc") + buf.write(u"\7t\2\2\u00dc\u00dd\7w\2\2\u00dd\u00e4\7g\2\2\u00de\u00df") + buf.write(u"\7H\2\2\u00df\u00e0\7c\2\2\u00e0\u00e1\7n\2\2\u00e1\u00e2") + buf.write(u"\7u\2\2\u00e2\u00e4\7g\2\2\u00e3\u00da\3\2\2\2\u00e3") + buf.write(u"\u00de\3\2\2\2\u00e4L\3\2\2\2\u00e5\u00e6\5O(\2\u00e6") + buf.write(u"\u00e7\7\60\2\2\u00e7\u00e9\5O(\2\u00e8\u00ea\5Q)\2\u00e9") + buf.write(u"\u00e8\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea\u00ef\3\2\2") + buf.write(u"\2\u00eb\u00ec\5O(\2\u00ec\u00ed\5Q)\2\u00ed\u00ef\3") + buf.write(u"\2\2\2\u00ee\u00e5\3\2\2\2\u00ee\u00eb\3\2\2\2\u00ef") + buf.write(u"N\3\2\2\2\u00f0\u00f2\5W,\2\u00f1\u00f0\3\2\2\2\u00f2") + buf.write(u"\u00f3\3\2\2\2\u00f3\u00f1\3\2\2\2\u00f3\u00f4\3\2\2") + buf.write(u"\2\u00f4P\3\2\2\2\u00f5\u00f7\t\3\2\2\u00f6\u00f8\t\4") + buf.write(u"\2\2\u00f7\u00f6\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") + buf.write(u"\3\2\2\2\u00f9\u00fa\5O(\2\u00faR\3\2\2\2\u00fb\u00fe") + buf.write(u"\7a\2\2\u00fc\u00fe\5U+\2\u00fd\u00fb\3\2\2\2\u00fd\u00fc") + buf.write(u"\3\2\2\2\u00fe\u0104\3\2\2\2\u00ff\u0103\7a\2\2\u0100") + buf.write(u"\u0103\5U+\2\u0101\u0103\5W,\2\u0102\u00ff\3\2\2\2\u0102") + buf.write(u"\u0100\3\2\2\2\u0102\u0101\3\2\2\2\u0103\u0106\3\2\2") + buf.write(u"\2\u0104\u0102\3\2\2\2\u0104\u0105\3\2\2\2\u0105T\3\2") + buf.write(u"\2\2\u0106\u0104\3\2\2\2\u0107\u0108\t\5\2\2\u0108V\3") + buf.write(u"\2\2\2\u0109\u010a\t\6\2\2\u010aX\3\2\2\2\16\2\u0096") + buf.write(u"\u00a0\u00ad\u00e3\u00e9\u00ee\u00f3\u00f7\u00fd\u0102") + buf.write(u"\u0104\3\b\2\2") + return buf.getvalue() + + +class RelayLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + SEMVER = 19 + WS = 20 + LINE_COMMENT = 21 + COMMENT = 22 + MUL = 23 + DIV = 24 + ADD = 25 + SUB = 26 + LT = 27 + GT = 28 + LE = 29 + GE = 30 + EQ = 31 + NE = 32 + GLOBAL_VAR = 33 + LOCAL_VAR = 34 + GRAPH_VAR = 35 + MUT = 36 + BOOL_LIT = 37 + FLOAT = 38 + NAT = 39 + CNAME = 40 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ u"DEFAULT_MODE" ] + + literalNames = [ u"", + u"'('", u"')'", u"','", u"'['", u"']'", u"'if'", u"'else'", + u"'let'", u"'='", u"';'", u"'{'", u"'}'", u"'fn'", u"'->'", + u"'def'", u"':'", u"'Tensor'", u"'_'", u"'v0.0.2'", u"'*'", + u"'/'", u"'+'", u"'-'", u"'<'", u"'>'", u"'<='", u"'>='", u"'=='", + u"'!='", u"'mut'" ] + + symbolicNames = [ u"", + u"SEMVER", u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", u"DIV", + u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", u"NE", u"GLOBAL_VAR", + u"LOCAL_VAR", u"GRAPH_VAR", u"MUT", u"BOOL_LIT", u"FLOAT", u"NAT", + u"CNAME" ] + + ruleNames = [ u"T__0", u"T__1", u"T__2", u"T__3", u"T__4", u"T__5", + u"T__6", u"T__7", u"T__8", u"T__9", u"T__10", u"T__11", + u"T__12", u"T__13", u"T__14", u"T__15", u"T__16", u"T__17", + u"SEMVER", u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", + u"DIV", u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", + u"NE", u"GLOBAL_VAR", u"LOCAL_VAR", u"GRAPH_VAR", u"MUT", + u"BOOL_LIT", u"FLOAT", u"NAT", u"EXP", u"CNAME", u"LETTER", + u"DIGIT" ] + + grammarFileName = u"Relay.g4" + + def __init__(self, input=None, output=sys.stdout): + super(RelayLexer, self).__init__(input, output=output) + self.checkVersion("4.7.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/python/tvm/relay/grammar/py2/RelayLexer.tokens b/python/tvm/relay/grammar/py2/RelayLexer.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py2/RelayParser.py b/python/tvm/relay/grammar/py2/RelayParser.py new file mode 100644 index 000000000000..77f56bf0545a --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayParser.py @@ -0,0 +1,2311 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from __future__ import print_function +from antlr4 import * +from io import StringIO +import sys + + +def serializedATN(): + with StringIO() as buf: + buf.write(u"\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3") + buf.write(u"*\u014c\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t") + buf.write(u"\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") + buf.write(u"\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4") + buf.write(u"\23\t\23\3\2\3\2\3\3\3\3\7\3+\n\3\f\3\16\3.\13\3\3\3") + buf.write(u"\5\3\61\n\3\3\3\3\3\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\6\4H\n\4\r") + buf.write(u"\4\16\4I\3\4\3\4\3\4\3\4\3\4\3\4\7\4R\n\4\f\4\16\4U\13") + buf.write(u"\4\5\4W\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\5\4d\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4n\n\4") + buf.write(u"\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write(u"\3\4\3\4\3\4\5\4\u0080\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\7\4\u0096\n\4\f\4\16\4\u0099\13\4\5\4\u009b\n\4\3") + buf.write(u"\4\7\4\u009e\n\4\f\4\16\4\u00a1\13\4\3\5\3\5\5\5\u00a5") + buf.write(u"\n\5\3\5\3\5\3\5\3\5\3\5\5\5\u00ac\n\5\3\5\3\5\3\6\3") + buf.write(u"\6\3\6\5\6\u00b3\n\6\3\6\3\6\3\6\3\6\3\6\5\6\u00ba\n") + buf.write(u"\6\3\6\3\6\3\7\3\7\3\7\3\7\3\7\3\7\5\7\u00c4\n\7\3\b") + buf.write(u"\3\b\3\b\7\b\u00c9\n\b\f\b\16\b\u00cc\13\b\5\b\u00ce") + buf.write(u"\n\b\3\t\3\t\3\t\5\t\u00d3\n\t\3\n\3\n\3\n\7\n\u00d8") + buf.write(u"\n\n\f\n\16\n\u00db\13\n\5\n\u00dd\n\n\3\13\3\13\3\13") + buf.write(u"\3\13\3\f\3\f\3\f\3\f\3\f\3\f\7\f\u00e9\n\f\f\f\16\f") + buf.write(u"\u00ec\13\f\3\f\3\f\5\f\u00f0\n\f\3\r\3\r\3\r\3\r\3\r") + buf.write(u"\3\r\3\r\3\r\3\r\3\r\3\r\6\r\u00fd\n\r\r\r\16\r\u00fe") + buf.write(u"\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\5\r") + buf.write(u"\u010d\n\r\3\r\3\r\3\r\3\r\7\r\u0113\n\r\f\r\16\r\u0116") + buf.write(u"\13\r\5\r\u0118\n\r\3\r\3\r\3\r\3\r\3\r\5\r\u011f\n\r") + buf.write(u"\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3") + buf.write(u"\16\6\16\u012c\n\16\r\16\16\16\u012d\3\16\3\16\5\16\u0132") + buf.write(u"\n\16\3\17\3\17\3\17\3\17\3\17\5\17\u0139\n\17\3\20\3") + buf.write(u"\20\3\21\3\21\3\21\3\21\3\22\3\22\3\22\5\22\u0144\n\22") + buf.write(u"\3\23\3\23\3\23\3\23\5\23\u014a\n\23\3\23\2\3\6\24\2") + buf.write(u"\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$\2\6\3\2\31") + buf.write(u"\32\3\2\33\34\3\2\35 \3\2!\"\2\u0175\2&\3\2\2\2\4(\3") + buf.write(u"\2\2\2\6\177\3\2\2\2\b\u00a2\3\2\2\2\n\u00af\3\2\2\2") + buf.write(u"\f\u00c3\3\2\2\2\16\u00cd\3\2\2\2\20\u00cf\3\2\2\2\22") + buf.write(u"\u00dc\3\2\2\2\24\u00de\3\2\2\2\26\u00ef\3\2\2\2\30\u011e") + buf.write(u"\3\2\2\2\32\u0131\3\2\2\2\34\u0138\3\2\2\2\36\u013a\3") + buf.write(u"\2\2\2 \u013c\3\2\2\2\"\u0143\3\2\2\2$\u0149\3\2\2\2") + buf.write(u"&\'\7*\2\2\'\3\3\2\2\2(\60\7\25\2\2)+\5\n\6\2*)\3\2\2") + buf.write(u"\2+.\3\2\2\2,*\3\2\2\2,-\3\2\2\2-\61\3\2\2\2.,\3\2\2") + buf.write(u"\2/\61\5\6\4\2\60,\3\2\2\2\60/\3\2\2\2\61\62\3\2\2\2") + buf.write(u"\62\63\7\2\2\3\63\5\3\2\2\2\64\65\b\4\1\2\65\66\7\3\2") + buf.write(u"\2\66\67\5\6\4\2\678\7\4\2\28\u0080\3\2\2\29:\7\34\2") + buf.write(u"\2:\u0080\5\6\4\23;\u0080\5\b\5\2<=\7\3\2\2=\u0080\7") + buf.write(u"\4\2\2>?\7\3\2\2?@\5\6\4\2@A\7\5\2\2AB\7\4\2\2B\u0080") + buf.write(u"\3\2\2\2CD\7\3\2\2DG\5\6\4\2EF\7\5\2\2FH\5\6\4\2GE\3") + buf.write(u"\2\2\2HI\3\2\2\2IG\3\2\2\2IJ\3\2\2\2JK\3\2\2\2KL\7\4") + buf.write(u"\2\2L\u0080\3\2\2\2MV\7\6\2\2NS\5\6\4\2OP\7\5\2\2PR\5") + buf.write(u"\6\4\2QO\3\2\2\2RU\3\2\2\2SQ\3\2\2\2ST\3\2\2\2TW\3\2") + buf.write(u"\2\2US\3\2\2\2VN\3\2\2\2VW\3\2\2\2WX\3\2\2\2X\u0080\7") + buf.write(u"\7\2\2YZ\7\b\2\2Z[\7\3\2\2[\\\5\6\4\2\\]\7\4\2\2]^\5") + buf.write(u" \21\2^_\7\t\2\2_`\5 \21\2`\u0080\3\2\2\2ac\7\n\2\2b") + buf.write(u"d\7&\2\2cb\3\2\2\2cd\3\2\2\2de\3\2\2\2ef\5\20\t\2fg\7") + buf.write(u"\13\2\2gh\5\6\4\2hi\7\f\2\2ij\5\6\4\bj\u0080\3\2\2\2") + buf.write(u"km\7\n\2\2ln\7&\2\2ml\3\2\2\2mn\3\2\2\2no\3\2\2\2op\5") + buf.write(u"\20\t\2pq\7\13\2\2qr\7\r\2\2rs\5\6\4\2st\7\16\2\2tu\7") + buf.write(u"\f\2\2uv\5\6\4\7v\u0080\3\2\2\2wx\5$\23\2xy\7\13\2\2") + buf.write(u"yz\5\6\4\2z{\7\f\2\2{|\5\6\4\5|\u0080\3\2\2\2}\u0080") + buf.write(u"\5$\23\2~\u0080\5\"\22\2\177\64\3\2\2\2\1779\3\2\2\2") + buf.write(u"\177;\3\2\2\2\177<\3\2\2\2\177>\3\2\2\2\177C\3\2\2\2") + buf.write(u"\177M\3\2\2\2\177Y\3\2\2\2\177a\3\2\2\2\177k\3\2\2\2") + buf.write(u"\177w\3\2\2\2\177}\3\2\2\2\177~\3\2\2\2\u0080\u009f\3") + buf.write(u"\2\2\2\u0081\u0082\f\22\2\2\u0082\u0083\t\2\2\2\u0083") + buf.write(u"\u009e\5\6\4\23\u0084\u0085\f\21\2\2\u0085\u0086\t\3") + buf.write(u"\2\2\u0086\u009e\5\6\4\22\u0087\u0088\f\20\2\2\u0088") + buf.write(u"\u0089\t\4\2\2\u0089\u009e\5\6\4\21\u008a\u008b\f\17") + buf.write(u"\2\2\u008b\u008c\t\5\2\2\u008c\u009e\5\6\4\20\u008d\u008e") + buf.write(u"\f\6\2\2\u008e\u008f\7\f\2\2\u008f\u009e\5\6\4\7\u0090") + buf.write(u"\u0091\f\24\2\2\u0091\u009a\7\3\2\2\u0092\u0097\5\6\4") + buf.write(u"\2\u0093\u0094\7\5\2\2\u0094\u0096\5\6\4\2\u0095\u0093") + buf.write(u"\3\2\2\2\u0096\u0099\3\2\2\2\u0097\u0095\3\2\2\2\u0097") + buf.write(u"\u0098\3\2\2\2\u0098\u009b\3\2\2\2\u0099\u0097\3\2\2") + buf.write(u"\2\u009a\u0092\3\2\2\2\u009a\u009b\3\2\2\2\u009b\u009c") + buf.write(u"\3\2\2\2\u009c\u009e\7\4\2\2\u009d\u0081\3\2\2\2\u009d") + buf.write(u"\u0084\3\2\2\2\u009d\u0087\3\2\2\2\u009d\u008a\3\2\2") + buf.write(u"\2\u009d\u008d\3\2\2\2\u009d\u0090\3\2\2\2\u009e\u00a1") + buf.write(u"\3\2\2\2\u009f\u009d\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0") + buf.write(u"\7\3\2\2\2\u00a1\u009f\3\2\2\2\u00a2\u00a4\7\17\2\2\u00a3") + buf.write(u"\u00a5\5\26\f\2\u00a4\u00a3\3\2\2\2\u00a4\u00a5\3\2\2") + buf.write(u"\2\u00a5\u00a6\3\2\2\2\u00a6\u00a7\7\3\2\2\u00a7\u00a8") + buf.write(u"\5\f\7\2\u00a8\u00ab\7\4\2\2\u00a9\u00aa\7\20\2\2\u00aa") + buf.write(u"\u00ac\5\30\r\2\u00ab\u00a9\3\2\2\2\u00ab\u00ac\3\2\2") + buf.write(u"\2\u00ac\u00ad\3\2\2\2\u00ad\u00ae\5 \21\2\u00ae\t\3") + buf.write(u"\2\2\2\u00af\u00b0\7\21\2\2\u00b0\u00b2\5$\23\2\u00b1") + buf.write(u"\u00b3\5\26\f\2\u00b2\u00b1\3\2\2\2\u00b2\u00b3\3\2\2") + buf.write(u"\2\u00b3\u00b4\3\2\2\2\u00b4\u00b5\7\3\2\2\u00b5\u00b6") + buf.write(u"\5\f\7\2\u00b6\u00b9\7\4\2\2\u00b7\u00b8\7\20\2\2\u00b8") + buf.write(u"\u00ba\5\30\r\2\u00b9\u00b7\3\2\2\2\u00b9\u00ba\3\2\2") + buf.write(u"\2\u00ba\u00bb\3\2\2\2\u00bb\u00bc\5 \21\2\u00bc\13\3") + buf.write(u"\2\2\2\u00bd\u00c4\5\16\b\2\u00be\u00c4\5\22\n\2\u00bf") + buf.write(u"\u00c0\5\16\b\2\u00c0\u00c1\7\5\2\2\u00c1\u00c2\5\22") + buf.write(u"\n\2\u00c2\u00c4\3\2\2\2\u00c3\u00bd\3\2\2\2\u00c3\u00be") + buf.write(u"\3\2\2\2\u00c3\u00bf\3\2\2\2\u00c4\r\3\2\2\2\u00c5\u00ca") + buf.write(u"\5\20\t\2\u00c6\u00c7\7\5\2\2\u00c7\u00c9\5\20\t\2\u00c8") + buf.write(u"\u00c6\3\2\2\2\u00c9\u00cc\3\2\2\2\u00ca\u00c8\3\2\2") + buf.write(u"\2\u00ca\u00cb\3\2\2\2\u00cb\u00ce\3\2\2\2\u00cc\u00ca") + buf.write(u"\3\2\2\2\u00cd\u00c5\3\2\2\2\u00cd\u00ce\3\2\2\2\u00ce") + buf.write(u"\17\3\2\2\2\u00cf\u00d2\5$\23\2\u00d0\u00d1\7\22\2\2") + buf.write(u"\u00d1\u00d3\5\30\r\2\u00d2\u00d0\3\2\2\2\u00d2\u00d3") + buf.write(u"\3\2\2\2\u00d3\21\3\2\2\2\u00d4\u00d9\5\24\13\2\u00d5") + buf.write(u"\u00d6\7\5\2\2\u00d6\u00d8\5\24\13\2\u00d7\u00d5\3\2") + buf.write(u"\2\2\u00d8\u00db\3\2\2\2\u00d9\u00d7\3\2\2\2\u00d9\u00da") + buf.write(u"\3\2\2\2\u00da\u00dd\3\2\2\2\u00db\u00d9\3\2\2\2\u00dc") + buf.write(u"\u00d4\3\2\2\2\u00dc\u00dd\3\2\2\2\u00dd\23\3\2\2\2\u00de") + buf.write(u"\u00df\7*\2\2\u00df\u00e0\7\13\2\2\u00e0\u00e1\5\6\4") + buf.write(u"\2\u00e1\25\3\2\2\2\u00e2\u00e3\7\6\2\2\u00e3\u00f0\7") + buf.write(u"\7\2\2\u00e4\u00e5\7\6\2\2\u00e5\u00ea\5$\23\2\u00e6") + buf.write(u"\u00e7\7\5\2\2\u00e7\u00e9\5$\23\2\u00e8\u00e6\3\2\2") + buf.write(u"\2\u00e9\u00ec\3\2\2\2\u00ea\u00e8\3\2\2\2\u00ea\u00eb") + buf.write(u"\3\2\2\2\u00eb\u00ed\3\2\2\2\u00ec\u00ea\3\2\2\2\u00ed") + buf.write(u"\u00ee\7\7\2\2\u00ee\u00f0\3\2\2\2\u00ef\u00e2\3\2\2") + buf.write(u"\2\u00ef\u00e4\3\2\2\2\u00f0\27\3\2\2\2\u00f1\u00f2\7") + buf.write(u"\3\2\2\u00f2\u011f\7\4\2\2\u00f3\u00f4\7\3\2\2\u00f4") + buf.write(u"\u00f5\5\30\r\2\u00f5\u00f6\7\5\2\2\u00f6\u00f7\7\4\2") + buf.write(u"\2\u00f7\u011f\3\2\2\2\u00f8\u00f9\7\3\2\2\u00f9\u00fc") + buf.write(u"\5\30\r\2\u00fa\u00fb\7\5\2\2\u00fb\u00fd\5\30\r\2\u00fc") + buf.write(u"\u00fa\3\2\2\2\u00fd\u00fe\3\2\2\2\u00fe\u00fc\3\2\2") + buf.write(u"\2\u00fe\u00ff\3\2\2\2\u00ff\u0100\3\2\2\2\u0100\u0101") + buf.write(u"\7\4\2\2\u0101\u011f\3\2\2\2\u0102\u011f\5\36\20\2\u0103") + buf.write(u"\u0104\7\23\2\2\u0104\u0105\7\6\2\2\u0105\u0106\5\32") + buf.write(u"\16\2\u0106\u0107\7\5\2\2\u0107\u0108\5\30\r\2\u0108") + buf.write(u"\u0109\7\7\2\2\u0109\u011f\3\2\2\2\u010a\u010c\7\17\2") + buf.write(u"\2\u010b\u010d\5\26\f\2\u010c\u010b\3\2\2\2\u010c\u010d") + buf.write(u"\3\2\2\2\u010d\u010e\3\2\2\2\u010e\u0117\7\3\2\2\u010f") + buf.write(u"\u0114\5\30\r\2\u0110\u0111\7\5\2\2\u0111\u0113\5\30") + buf.write(u"\r\2\u0112\u0110\3\2\2\2\u0113\u0116\3\2\2\2\u0114\u0112") + buf.write(u"\3\2\2\2\u0114\u0115\3\2\2\2\u0115\u0118\3\2\2\2\u0116") + buf.write(u"\u0114\3\2\2\2\u0117\u010f\3\2\2\2\u0117\u0118\3\2\2") + buf.write(u"\2\u0118\u0119\3\2\2\2\u0119\u011a\7\4\2\2\u011a\u011b") + buf.write(u"\7\20\2\2\u011b\u011f\5\30\r\2\u011c\u011f\7\24\2\2\u011d") + buf.write(u"\u011f\7)\2\2\u011e\u00f1\3\2\2\2\u011e\u00f3\3\2\2\2") + buf.write(u"\u011e\u00f8\3\2\2\2\u011e\u0102\3\2\2\2\u011e\u0103") + buf.write(u"\3\2\2\2\u011e\u010a\3\2\2\2\u011e\u011c\3\2\2\2\u011e") + buf.write(u"\u011d\3\2\2\2\u011f\31\3\2\2\2\u0120\u0121\7\3\2\2\u0121") + buf.write(u"\u0132\7\4\2\2\u0122\u0123\7\3\2\2\u0123\u0124\5\34\17") + buf.write(u"\2\u0124\u0125\7\5\2\2\u0125\u0126\7\4\2\2\u0126\u0132") + buf.write(u"\3\2\2\2\u0127\u0128\7\3\2\2\u0128\u012b\5\34\17\2\u0129") + buf.write(u"\u012a\7\5\2\2\u012a\u012c\5\34\17\2\u012b\u0129\3\2") + buf.write(u"\2\2\u012c\u012d\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e") + buf.write(u"\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0130\7\4\2\2\u0130") + buf.write(u"\u0132\3\2\2\2\u0131\u0120\3\2\2\2\u0131\u0122\3\2\2") + buf.write(u"\2\u0131\u0127\3\2\2\2\u0132\33\3\2\2\2\u0133\u0134\7") + buf.write(u"\3\2\2\u0134\u0135\5\34\17\2\u0135\u0136\7\4\2\2\u0136") + buf.write(u"\u0139\3\2\2\2\u0137\u0139\7)\2\2\u0138\u0133\3\2\2\2") + buf.write(u"\u0138\u0137\3\2\2\2\u0139\35\3\2\2\2\u013a\u013b\7*") + buf.write(u"\2\2\u013b\37\3\2\2\2\u013c\u013d\7\r\2\2\u013d\u013e") + buf.write(u"\5\6\4\2\u013e\u013f\7\16\2\2\u013f!\3\2\2\2\u0140\u0144") + buf.write(u"\7(\2\2\u0141\u0144\7)\2\2\u0142\u0144\7\'\2\2\u0143") + buf.write(u"\u0140\3\2\2\2\u0143\u0141\3\2\2\2\u0143\u0142\3\2\2") + buf.write(u"\2\u0144#\3\2\2\2\u0145\u014a\5\2\2\2\u0146\u014a\7#") + buf.write(u"\2\2\u0147\u014a\7$\2\2\u0148\u014a\7%\2\2\u0149\u0145") + buf.write(u"\3\2\2\2\u0149\u0146\3\2\2\2\u0149\u0147\3\2\2\2\u0149") + buf.write(u"\u0148\3\2\2\2\u014a%\3\2\2\2$,\60ISVcm\177\u0097\u009a") + buf.write(u"\u009d\u009f\u00a4\u00ab\u00b2\u00b9\u00c3\u00ca\u00cd") + buf.write(u"\u00d2\u00d9\u00dc\u00ea\u00ef\u00fe\u010c\u0114\u0117") + buf.write(u"\u011e\u012d\u0131\u0138\u0143\u0149") + return buf.getvalue() + + +class RelayParser ( Parser ): + + grammarFileName = "Relay.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ u"", u"'('", u"')'", u"','", u"'['", u"']'", + u"'if'", u"'else'", u"'let'", u"'='", u"';'", u"'{'", + u"'}'", u"'fn'", u"'->'", u"'def'", u"':'", u"'Tensor'", + u"'_'", u"'v0.0.2'", u"", u"", u"", + u"'*'", u"'/'", u"'+'", u"'-'", u"'<'", u"'>'", u"'<='", + u"'>='", u"'=='", u"'!='", u"", u"", + u"", u"'mut'" ] + + symbolicNames = [ u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"SEMVER", + u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", u"DIV", + u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", + u"NE", u"GLOBAL_VAR", u"LOCAL_VAR", u"GRAPH_VAR", + u"MUT", u"BOOL_LIT", u"FLOAT", u"NAT", u"CNAME" ] + + RULE_opIdent = 0 + RULE_prog = 1 + RULE_expr = 2 + RULE_func = 3 + RULE_defn = 4 + RULE_argList = 5 + RULE_varList = 6 + RULE_var = 7 + RULE_attrList = 8 + RULE_attr = 9 + RULE_typeParamSeq = 10 + RULE_type_ = 11 + RULE_shapeSeq = 12 + RULE_shape = 13 + RULE_typeIdent = 14 + RULE_body = 15 + RULE_scalar = 16 + RULE_ident = 17 + + ruleNames = [ u"opIdent", u"prog", u"expr", u"func", u"defn", u"argList", + u"varList", u"var", u"attrList", u"attr", u"typeParamSeq", + u"type_", u"shapeSeq", u"shape", u"typeIdent", u"body", + u"scalar", u"ident" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + SEMVER=19 + WS=20 + LINE_COMMENT=21 + COMMENT=22 + MUL=23 + DIV=24 + ADD=25 + SUB=26 + LT=27 + GT=28 + LE=29 + GE=30 + EQ=31 + NE=32 + GLOBAL_VAR=33 + LOCAL_VAR=34 + GRAPH_VAR=35 + MUT=36 + BOOL_LIT=37 + FLOAT=38 + NAT=39 + CNAME=40 + + def __init__(self, input, output=sys.stdout): + super(RelayParser, self).__init__(input, output=output) + self.checkVersion("4.7.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class OpIdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.OpIdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_opIdent + + def accept(self, visitor): + if hasattr(visitor, "visitOpIdent"): + return visitor.visitOpIdent(self) + else: + return visitor.visitChildren(self) + + + + + def opIdent(self): + + localctx = RelayParser.OpIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_opIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 36 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ProgContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ProgContext, self).__init__(parent, invokingState) + self.parser = parser + + def SEMVER(self): + return self.getToken(RelayParser.SEMVER, 0) + + def EOF(self): + return self.getToken(RelayParser.EOF, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def defn(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.DefnContext) + else: + return self.getTypedRuleContext(RelayParser.DefnContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_prog + + def accept(self, visitor): + if hasattr(visitor, "visitProg"): + return visitor.visitProg(self) + else: + return visitor.visitChildren(self) + + + + + def prog(self): + + localctx = RelayParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 38 + self.match(RelayParser.SEMVER) + self.state = 46 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.EOF, RelayParser.T__14]: + self.state = 42 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__14: + self.state = 39 + self.defn() + self.state = 44 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [RelayParser.T__0, RelayParser.T__3, RelayParser.T__5, RelayParser.T__7, RelayParser.T__12, RelayParser.SUB, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.BOOL_LIT, RelayParser.FLOAT, RelayParser.NAT, RelayParser.CNAME]: + self.state = 45 + self.expr(0) + pass + else: + raise NoViableAltException(self) + + self.state = 48 + self.match(RelayParser.EOF) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ExprContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_expr + + + def copyFrom(self, ctx): + super(RelayParser.ExprContext, self).copyFrom(ctx) + + + class IdentExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.IdentExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitIdentExpr"): + return visitor.visitIdentExpr(self) + else: + return visitor.visitChildren(self) + + + class CallContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.CallContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitCall"): + return visitor.visitCall(self) + else: + return visitor.visitChildren(self) + + + class NegContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.NegContext, self).__init__(parser) + self.copyFrom(ctx) + + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitNeg"): + return visitor.visitNeg(self) + else: + return visitor.visitChildren(self) + + + class TupleContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.TupleContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTuple"): + return visitor.visitTuple(self) + else: + return visitor.visitChildren(self) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.ParensContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitParens"): + return visitor.visitParens(self) + else: + return visitor.visitChildren(self) + + + class FuncExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.FuncExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def func(self): + return self.getTypedRuleContext(RelayParser.FuncContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitFuncExpr"): + return visitor.visitFuncExpr(self) + else: + return visitor.visitChildren(self) + + + class ScalarExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.ScalarExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def scalar(self): + return self.getTypedRuleContext(RelayParser.ScalarContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitScalarExpr"): + return visitor.visitScalarExpr(self) + else: + return visitor.visitChildren(self) + + + class LetContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.LetContext, self).__init__(parser) + self.copyFrom(ctx) + + def var(self): + return self.getTypedRuleContext(RelayParser.VarContext,0) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUT(self): + return self.getToken(RelayParser.MUT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitLet"): + return visitor.visitLet(self) + else: + return visitor.visitChildren(self) + + + class TensorContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.TensorContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTensor"): + return visitor.visitTensor(self) + else: + return visitor.visitChildren(self) + + + class IfElseContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.IfElseContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + def body(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.BodyContext) + else: + return self.getTypedRuleContext(RelayParser.BodyContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitIfElse"): + return visitor.visitIfElse(self) + else: + return visitor.visitChildren(self) + + + class GraphContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.GraphContext, self).__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitGraph"): + return visitor.visitGraph(self) + else: + return visitor.visitChildren(self) + + + class BinOpContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.BinOpContext, self).__init__(parser) + self.op = None # Token + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUL(self): + return self.getToken(RelayParser.MUL, 0) + def DIV(self): + return self.getToken(RelayParser.DIV, 0) + def ADD(self): + return self.getToken(RelayParser.ADD, 0) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def LT(self): + return self.getToken(RelayParser.LT, 0) + def GT(self): + return self.getToken(RelayParser.GT, 0) + def LE(self): + return self.getToken(RelayParser.LE, 0) + def GE(self): + return self.getToken(RelayParser.GE, 0) + def EQ(self): + return self.getToken(RelayParser.EQ, 0) + def NE(self): + return self.getToken(RelayParser.NE, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitBinOp"): + return visitor.visitBinOp(self) + else: + return visitor.visitChildren(self) + + + + def expr(self, _p=0): + _parentctx = self._ctx + _parentState = self.state + localctx = RelayParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 4 + self.enterRecursionRule(localctx, 4, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 125 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + if la_ == 1: + localctx = RelayParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 51 + self.match(RelayParser.T__0) + self.state = 52 + self.expr(0) + self.state = 53 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.NegContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 55 + self.match(RelayParser.SUB) + self.state = 56 + self.expr(17) + pass + + elif la_ == 3: + localctx = RelayParser.FuncExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 57 + self.func() + pass + + elif la_ == 4: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 58 + self.match(RelayParser.T__0) + self.state = 59 + self.match(RelayParser.T__1) + pass + + elif la_ == 5: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 60 + self.match(RelayParser.T__0) + self.state = 61 + self.expr(0) + self.state = 62 + self.match(RelayParser.T__2) + self.state = 63 + self.match(RelayParser.T__1) + pass + + elif la_ == 6: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 65 + self.match(RelayParser.T__0) + self.state = 66 + self.expr(0) + self.state = 69 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 67 + self.match(RelayParser.T__2) + self.state = 68 + self.expr(0) + self.state = 71 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 73 + self.match(RelayParser.T__1) + pass + + elif la_ == 7: + localctx = RelayParser.TensorContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 75 + self.match(RelayParser.T__3) + self.state = 84 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 76 + self.expr(0) + self.state = 81 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 77 + self.match(RelayParser.T__2) + self.state = 78 + self.expr(0) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 86 + self.match(RelayParser.T__4) + pass + + elif la_ == 8: + localctx = RelayParser.IfElseContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 87 + self.match(RelayParser.T__5) + self.state = 88 + self.match(RelayParser.T__0) + self.state = 89 + self.expr(0) + self.state = 90 + self.match(RelayParser.T__1) + self.state = 91 + self.body() + self.state = 92 + self.match(RelayParser.T__6) + self.state = 93 + self.body() + pass + + elif la_ == 9: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 95 + self.match(RelayParser.T__7) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 96 + self.match(RelayParser.MUT) + + + self.state = 99 + self.var() + self.state = 100 + self.match(RelayParser.T__8) + self.state = 101 + self.expr(0) + self.state = 102 + self.match(RelayParser.T__9) + self.state = 103 + self.expr(6) + pass + + elif la_ == 10: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 105 + self.match(RelayParser.T__7) + self.state = 107 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 106 + self.match(RelayParser.MUT) + + + self.state = 109 + self.var() + self.state = 110 + self.match(RelayParser.T__8) + self.state = 111 + self.match(RelayParser.T__10) + self.state = 112 + self.expr(0) + self.state = 113 + self.match(RelayParser.T__11) + self.state = 114 + self.match(RelayParser.T__9) + self.state = 115 + self.expr(5) + pass + + elif la_ == 11: + localctx = RelayParser.GraphContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 117 + self.ident() + self.state = 118 + self.match(RelayParser.T__8) + self.state = 119 + self.expr(0) + self.state = 120 + self.match(RelayParser.T__9) + self.state = 121 + self.expr(3) + pass + + elif la_ == 12: + localctx = RelayParser.IdentExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 123 + self.ident() + pass + + elif la_ == 13: + localctx = RelayParser.ScalarExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 124 + self.scalar() + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 157 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 155 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,10,self._ctx) + if la_ == 1: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 127 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 128 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.MUL or _la==RelayParser.DIV): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 129 + self.expr(17) + pass + + elif la_ == 2: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 130 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 131 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.ADD or _la==RelayParser.SUB): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 132 + self.expr(16) + pass + + elif la_ == 3: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 133 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 134 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.expr(15) + pass + + elif la_ == 4: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 136 + if not self.precpred(self._ctx, 13): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 13)") + self.state = 137 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.EQ or _la==RelayParser.NE): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 138 + self.expr(14) + pass + + elif la_ == 5: + localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 139 + if not self.precpred(self._ctx, 4): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 140 + self.match(RelayParser.T__9) + self.state = 141 + self.expr(5) + pass + + elif la_ == 6: + localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 142 + if not self.precpred(self._ctx, 18): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") + self.state = 143 + self.match(RelayParser.T__0) + self.state = 152 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 144 + self.expr(0) + self.state = 149 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 145 + self.match(RelayParser.T__2) + self.state = 146 + self.expr(0) + self.state = 151 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 154 + self.match(RelayParser.T__1) + pass + + + self.state = 159 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class FuncContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.FuncContext, self).__init__(parent, invokingState) + self.parser = parser + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_func + + def accept(self, visitor): + if hasattr(visitor, "visitFunc"): + return visitor.visitFunc(self) + else: + return visitor.visitChildren(self) + + + + + def func(self): + + localctx = RelayParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_func) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 160 + self.match(RelayParser.T__12) + self.state = 162 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 161 + self.typeParamSeq() + + + self.state = 164 + self.match(RelayParser.T__0) + self.state = 165 + self.argList() + self.state = 166 + self.match(RelayParser.T__1) + self.state = 169 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 167 + self.match(RelayParser.T__13) + self.state = 168 + self.type_() + + + self.state = 171 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DefnContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.DefnContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_defn + + def accept(self, visitor): + if hasattr(visitor, "visitDefn"): + return visitor.visitDefn(self) + else: + return visitor.visitChildren(self) + + + + + def defn(self): + + localctx = RelayParser.DefnContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_defn) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 173 + self.match(RelayParser.T__14) + self.state = 174 + self.ident() + self.state = 176 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 175 + self.typeParamSeq() + + + self.state = 178 + self.match(RelayParser.T__0) + self.state = 179 + self.argList() + self.state = 180 + self.match(RelayParser.T__1) + self.state = 183 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 181 + self.match(RelayParser.T__13) + self.state = 182 + self.type_() + + + self.state = 185 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ArgListContext, self).__init__(parent, invokingState) + self.parser = parser + + def varList(self): + return self.getTypedRuleContext(RelayParser.VarListContext,0) + + + def attrList(self): + return self.getTypedRuleContext(RelayParser.AttrListContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_argList + + def accept(self, visitor): + if hasattr(visitor, "visitArgList"): + return visitor.visitArgList(self) + else: + return visitor.visitChildren(self) + + + + + def argList(self): + + localctx = RelayParser.ArgListContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_argList) + try: + self.state = 193 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,16,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 187 + self.varList() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 188 + self.attrList() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 189 + self.varList() + self.state = 190 + self.match(RelayParser.T__2) + self.state = 191 + self.attrList() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.VarListContext, self).__init__(parent, invokingState) + self.parser = parser + + def var(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.VarContext) + else: + return self.getTypedRuleContext(RelayParser.VarContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_varList + + def accept(self, visitor): + if hasattr(visitor, "visitVarList"): + return visitor.visitVarList(self) + else: + return visitor.visitChildren(self) + + + + + def varList(self): + + localctx = RelayParser.VarListContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_varList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.CNAME))) != 0): + self.state = 195 + self.var() + self.state = 200 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 196 + self.match(RelayParser.T__2) + self.state = 197 + self.var() + self.state = 202 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.VarContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_var + + def accept(self, visitor): + if hasattr(visitor, "visitVar"): + return visitor.visitVar(self) + else: + return visitor.visitChildren(self) + + + + + def var(self): + + localctx = RelayParser.VarContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_var) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 205 + self.ident() + self.state = 208 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__15: + self.state = 206 + self.match(RelayParser.T__15) + self.state = 207 + self.type_() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.AttrListContext, self).__init__(parent, invokingState) + self.parser = parser + + def attr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.AttrContext) + else: + return self.getTypedRuleContext(RelayParser.AttrContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_attrList + + def accept(self, visitor): + if hasattr(visitor, "visitAttrList"): + return visitor.visitAttrList(self) + else: + return visitor.visitChildren(self) + + + + + def attrList(self): + + localctx = RelayParser.AttrListContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_attrList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.CNAME: + self.state = 210 + self.attr() + self.state = 215 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 211 + self.match(RelayParser.T__2) + self.state = 212 + self.attr() + self.state = 217 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.AttrContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_attr + + def accept(self, visitor): + if hasattr(visitor, "visitAttr"): + return visitor.visitAttr(self) + else: + return visitor.visitChildren(self) + + + + + def attr(self): + + localctx = RelayParser.AttrContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_attr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 220 + self.match(RelayParser.CNAME) + self.state = 221 + self.match(RelayParser.T__8) + self.state = 222 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeParamSeqContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.TypeParamSeqContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.IdentContext) + else: + return self.getTypedRuleContext(RelayParser.IdentContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_typeParamSeq + + def accept(self, visitor): + if hasattr(visitor, "visitTypeParamSeq"): + return visitor.visitTypeParamSeq(self) + else: + return visitor.visitChildren(self) + + + + + def typeParamSeq(self): + + localctx = RelayParser.TypeParamSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_typeParamSeq) + self._la = 0 # Token type + try: + self.state = 237 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,23,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 224 + self.match(RelayParser.T__3) + self.state = 225 + self.match(RelayParser.T__4) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 226 + self.match(RelayParser.T__3) + self.state = 227 + self.ident() + self.state = 232 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 228 + self.match(RelayParser.T__2) + self.state = 229 + self.ident() + self.state = 234 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 235 + self.match(RelayParser.T__4) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Type_Context(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.Type_Context, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_type_ + + + def copyFrom(self, ctx): + super(RelayParser.Type_Context, self).copyFrom(ctx) + + + + class IntTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.IntTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitIntType"): + return visitor.visitIntType(self) + else: + return visitor.visitChildren(self) + + + class TupleTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TupleTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def type_(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTupleType"): + return visitor.visitTupleType(self) + else: + return visitor.visitChildren(self) + + + class TypeIdentTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TypeIdentTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def typeIdent(self): + return self.getTypedRuleContext(RelayParser.TypeIdentContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitTypeIdentType"): + return visitor.visitTypeIdentType(self) + else: + return visitor.visitChildren(self) + + + class IncompleteTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.IncompleteTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + + def accept(self, visitor): + if hasattr(visitor, "visitIncompleteType"): + return visitor.visitIncompleteType(self) + else: + return visitor.visitChildren(self) + + + class TensorTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TensorTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def shapeSeq(self): + return self.getTypedRuleContext(RelayParser.ShapeSeqContext,0) + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitTensorType"): + return visitor.visitTensorType(self) + else: + return visitor.visitChildren(self) + + + class FuncTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.FuncTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def type_(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitFuncType"): + return visitor.visitFuncType(self) + else: + return visitor.visitChildren(self) + + + + def type_(self): + + localctx = RelayParser.Type_Context(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_type_) + self._la = 0 # Token type + try: + self.state = 284 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 239 + self.match(RelayParser.T__0) + self.state = 240 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 241 + self.match(RelayParser.T__0) + self.state = 242 + self.type_() + self.state = 243 + self.match(RelayParser.T__2) + self.state = 244 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.match(RelayParser.T__0) + self.state = 247 + self.type_() + self.state = 250 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 248 + self.match(RelayParser.T__2) + self.state = 249 + self.type_() + self.state = 252 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 254 + self.match(RelayParser.T__1) + pass + + elif la_ == 4: + localctx = RelayParser.TypeIdentTypeContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 256 + self.typeIdent() + pass + + elif la_ == 5: + localctx = RelayParser.TensorTypeContext(self, localctx) + self.enterOuterAlt(localctx, 5) + self.state = 257 + self.match(RelayParser.T__16) + self.state = 258 + self.match(RelayParser.T__3) + self.state = 259 + self.shapeSeq() + self.state = 260 + self.match(RelayParser.T__2) + self.state = 261 + self.type_() + self.state = 262 + self.match(RelayParser.T__4) + pass + + elif la_ == 6: + localctx = RelayParser.FuncTypeContext(self, localctx) + self.enterOuterAlt(localctx, 6) + self.state = 264 + self.match(RelayParser.T__12) + self.state = 266 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 265 + self.typeParamSeq() + + + self.state = 268 + self.match(RelayParser.T__0) + self.state = 277 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__12) | (1 << RelayParser.T__16) | (1 << RelayParser.T__17) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 269 + self.type_() + self.state = 274 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 270 + self.match(RelayParser.T__2) + self.state = 271 + self.type_() + self.state = 276 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 279 + self.match(RelayParser.T__1) + self.state = 280 + self.match(RelayParser.T__13) + self.state = 281 + self.type_() + pass + + elif la_ == 7: + localctx = RelayParser.IncompleteTypeContext(self, localctx) + self.enterOuterAlt(localctx, 7) + self.state = 282 + self.match(RelayParser.T__17) + pass + + elif la_ == 8: + localctx = RelayParser.IntTypeContext(self, localctx) + self.enterOuterAlt(localctx, 8) + self.state = 283 + self.match(RelayParser.NAT) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeSeqContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ShapeSeqContext, self).__init__(parent, invokingState) + self.parser = parser + + def shape(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ShapeContext) + else: + return self.getTypedRuleContext(RelayParser.ShapeContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_shapeSeq + + def accept(self, visitor): + if hasattr(visitor, "visitShapeSeq"): + return visitor.visitShapeSeq(self) + else: + return visitor.visitChildren(self) + + + + + def shapeSeq(self): + + localctx = RelayParser.ShapeSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_shapeSeq) + self._la = 0 # Token type + try: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 286 + self.match(RelayParser.T__0) + self.state = 287 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 288 + self.match(RelayParser.T__0) + self.state = 289 + self.shape() + self.state = 290 + self.match(RelayParser.T__2) + self.state = 291 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 293 + self.match(RelayParser.T__0) + self.state = 294 + self.shape() + self.state = 297 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 295 + self.match(RelayParser.T__2) + self.state = 296 + self.shape() + self.state = 299 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 301 + self.match(RelayParser.T__1) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ShapeContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_shape + + + def copyFrom(self, ctx): + super(RelayParser.ShapeContext, self).copyFrom(ctx) + + + + class ParensShapeContext(ShapeContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ShapeContext) + super(RelayParser.ParensShapeContext, self).__init__(parser) + self.copyFrom(ctx) + + def shape(self): + return self.getTypedRuleContext(RelayParser.ShapeContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitParensShape"): + return visitor.visitParensShape(self) + else: + return visitor.visitChildren(self) + + + class IntShapeContext(ShapeContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ShapeContext) + super(RelayParser.IntShapeContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitIntShape"): + return visitor.visitIntShape(self) + else: + return visitor.visitChildren(self) + + + + def shape(self): + + localctx = RelayParser.ShapeContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_shape) + try: + self.state = 310 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.T__0]: + localctx = RelayParser.ParensShapeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 305 + self.match(RelayParser.T__0) + self.state = 306 + self.shape() + self.state = 307 + self.match(RelayParser.T__1) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.IntShapeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 309 + self.match(RelayParser.NAT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeIdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.TypeIdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_typeIdent + + def accept(self, visitor): + if hasattr(visitor, "visitTypeIdent"): + return visitor.visitTypeIdent(self) + else: + return visitor.visitChildren(self) + + + + + def typeIdent(self): + + localctx = RelayParser.TypeIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_typeIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 312 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BodyContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.BodyContext, self).__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_body + + def accept(self, visitor): + if hasattr(visitor, "visitBody"): + return visitor.visitBody(self) + else: + return visitor.visitChildren(self) + + + + + def body(self): + + localctx = RelayParser.BodyContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_body) + try: + self.enterOuterAlt(localctx, 1) + self.state = 314 + self.match(RelayParser.T__10) + self.state = 315 + self.expr(0) + self.state = 316 + self.match(RelayParser.T__11) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ScalarContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ScalarContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_scalar + + + def copyFrom(self, ctx): + super(RelayParser.ScalarContext, self).copyFrom(ctx) + + + + class ScalarFloatContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarFloatContext, self).__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(RelayParser.FLOAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarFloat"): + return visitor.visitScalarFloat(self) + else: + return visitor.visitChildren(self) + + + class ScalarBoolContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarBoolContext, self).__init__(parser) + self.copyFrom(ctx) + + def BOOL_LIT(self): + return self.getToken(RelayParser.BOOL_LIT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarBool"): + return visitor.visitScalarBool(self) + else: + return visitor.visitChildren(self) + + + class ScalarIntContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarIntContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarInt"): + return visitor.visitScalarInt(self) + else: + return visitor.visitChildren(self) + + + + def scalar(self): + + localctx = RelayParser.ScalarContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_scalar) + try: + self.state = 321 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.FLOAT]: + localctx = RelayParser.ScalarFloatContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.match(RelayParser.FLOAT) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.ScalarIntContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 319 + self.match(RelayParser.NAT) + pass + elif token in [RelayParser.BOOL_LIT]: + localctx = RelayParser.ScalarBoolContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 320 + self.match(RelayParser.BOOL_LIT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.IdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def opIdent(self): + return self.getTypedRuleContext(RelayParser.OpIdentContext,0) + + + def GLOBAL_VAR(self): + return self.getToken(RelayParser.GLOBAL_VAR, 0) + + def LOCAL_VAR(self): + return self.getToken(RelayParser.LOCAL_VAR, 0) + + def GRAPH_VAR(self): + return self.getToken(RelayParser.GRAPH_VAR, 0) + + def getRuleIndex(self): + return RelayParser.RULE_ident + + def accept(self, visitor): + if hasattr(visitor, "visitIdent"): + return visitor.visitIdent(self) + else: + return visitor.visitChildren(self) + + + + + def ident(self): + + localctx = RelayParser.IdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_ident) + try: + self.state = 327 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.CNAME]: + self.enterOuterAlt(localctx, 1) + self.state = 323 + self.opIdent() + pass + elif token in [RelayParser.GLOBAL_VAR]: + self.enterOuterAlt(localctx, 2) + self.state = 324 + self.match(RelayParser.GLOBAL_VAR) + pass + elif token in [RelayParser.LOCAL_VAR]: + self.enterOuterAlt(localctx, 3) + self.state = 325 + self.match(RelayParser.LOCAL_VAR) + pass + elif token in [RelayParser.GRAPH_VAR]: + self.enterOuterAlt(localctx, 4) + self.state = 326 + self.match(RelayParser.GRAPH_VAR) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx, ruleIndex, predIndex): + if self._predicates == None: + self._predicates = dict() + self._predicates[2] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx, predIndex): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 13) + + + if predIndex == 4: + return self.precpred(self._ctx, 4) + + + if predIndex == 5: + return self.precpred(self._ctx, 18) + + + + + diff --git a/python/tvm/relay/grammar/py2/RelayVisitor.py b/python/tvm/relay/grammar/py2/RelayVisitor.py new file mode 100644 index 000000000000..eae67d8cff58 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayVisitor.py @@ -0,0 +1,192 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * + +# This class defines a complete generic visitor for a parse tree produced by RelayParser. + +class RelayVisitor(ParseTreeVisitor): + + # Visit a parse tree produced by RelayParser#opIdent. + def visitOpIdent(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#prog. + def visitProg(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#identExpr. + def visitIdentExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#call. + def visitCall(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#neg. + def visitNeg(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tuple. + def visitTuple(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parens. + def visitParens(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcExpr. + def visitFuncExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarExpr. + def visitScalarExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#let. + def visitLet(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensor. + def visitTensor(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ifElse. + def visitIfElse(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#graph. + def visitGraph(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#binOp. + def visitBinOp(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#func. + def visitFunc(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#defn. + def visitDefn(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#argList. + def visitArgList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#varList. + def visitVarList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#var. + def visitVar(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attrList. + def visitAttrList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attr. + def visitAttr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeParamSeq. + def visitTypeParamSeq(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tupleType. + def visitTupleType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdentType. + def visitTypeIdentType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensorType. + def visitTensorType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcType. + def visitFuncType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#incompleteType. + def visitIncompleteType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intType. + def visitIntType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#shapeSeq. + def visitShapeSeq(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parensShape. + def visitParensShape(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intShape. + def visitIntShape(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdent. + def visitTypeIdent(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#body. + def visitBody(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarFloat. + def visitScalarFloat(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarInt. + def visitScalarInt(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarBool. + def visitScalarBool(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ident. + def visitIdent(self, ctx): + return self.visitChildren(ctx) + + diff --git a/python/tvm/relay/grammar/py3/.gitattributes b/python/tvm/relay/grammar/py3/.gitattributes new file mode 100644 index 000000000000..4adf65fa2f3c --- /dev/null +++ b/python/tvm/relay/grammar/py3/.gitattributes @@ -0,0 +1,3 @@ +Relay* binary +Relay* linguist-generated=true +Relay* linguist-detectable=false \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/.gitignore b/python/tvm/relay/grammar/py3/.gitignore deleted file mode 100644 index d677ff551940..000000000000 --- a/python/tvm/relay/grammar/py3/.gitignore +++ /dev/null @@ -1 +0,0 @@ -Relay* diff --git a/python/tvm/relay/grammar/py3/Relay.interp b/python/tvm/relay/grammar/py3/Relay.interp new file mode 100644 index 000000000000..c6893d096168 --- /dev/null +++ b/python/tvm/relay/grammar/py3/Relay.interp @@ -0,0 +1,109 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +opIdent +prog +expr +func +defn +argList +varList +var +attrList +attr +typeParamSeq +type_ +shapeSeq +shape +typeIdent +body +scalar +ident + + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 42, 332, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 3, 2, 3, 2, 3, 3, 3, 3, 7, 3, 43, 10, 3, 12, 3, 14, 3, 46, 11, 3, 3, 3, 5, 3, 49, 10, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 6, 4, 72, 10, 4, 13, 4, 14, 4, 73, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 82, 10, 4, 12, 4, 14, 4, 85, 11, 4, 5, 4, 87, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 100, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 110, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 128, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 150, 10, 4, 12, 4, 14, 4, 153, 11, 4, 5, 4, 155, 10, 4, 3, 4, 7, 4, 158, 10, 4, 12, 4, 14, 4, 161, 11, 4, 3, 5, 3, 5, 5, 5, 165, 10, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 5, 5, 172, 10, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 5, 6, 179, 10, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 5, 6, 186, 10, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 5, 7, 196, 10, 7, 3, 8, 3, 8, 3, 8, 7, 8, 201, 10, 8, 12, 8, 14, 8, 204, 11, 8, 5, 8, 206, 10, 8, 3, 9, 3, 9, 3, 9, 5, 9, 211, 10, 9, 3, 10, 3, 10, 3, 10, 7, 10, 216, 10, 10, 12, 10, 14, 10, 219, 11, 10, 5, 10, 221, 10, 10, 3, 11, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 7, 12, 233, 10, 12, 12, 12, 14, 12, 236, 11, 12, 3, 12, 3, 12, 5, 12, 240, 10, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 6, 13, 253, 10, 13, 13, 13, 14, 13, 254, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 269, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 7, 13, 275, 10, 13, 12, 13, 14, 13, 278, 11, 13, 5, 13, 280, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 287, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 6, 14, 300, 10, 14, 13, 14, 14, 14, 301, 3, 14, 3, 14, 5, 14, 306, 10, 14, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 5, 15, 313, 10, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 5, 18, 324, 10, 18, 3, 19, 3, 19, 3, 19, 3, 19, 5, 19, 330, 10, 19, 3, 19, 2, 3, 6, 20, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 2, 6, 3, 2, 25, 26, 3, 2, 27, 28, 3, 2, 29, 32, 3, 2, 33, 34, 2, 373, 2, 38, 3, 2, 2, 2, 4, 40, 3, 2, 2, 2, 6, 127, 3, 2, 2, 2, 8, 162, 3, 2, 2, 2, 10, 175, 3, 2, 2, 2, 12, 195, 3, 2, 2, 2, 14, 205, 3, 2, 2, 2, 16, 207, 3, 2, 2, 2, 18, 220, 3, 2, 2, 2, 20, 222, 3, 2, 2, 2, 22, 239, 3, 2, 2, 2, 24, 286, 3, 2, 2, 2, 26, 305, 3, 2, 2, 2, 28, 312, 3, 2, 2, 2, 30, 314, 3, 2, 2, 2, 32, 316, 3, 2, 2, 2, 34, 323, 3, 2, 2, 2, 36, 329, 3, 2, 2, 2, 38, 39, 7, 42, 2, 2, 39, 3, 3, 2, 2, 2, 40, 48, 7, 21, 2, 2, 41, 43, 5, 10, 6, 2, 42, 41, 3, 2, 2, 2, 43, 46, 3, 2, 2, 2, 44, 42, 3, 2, 2, 2, 44, 45, 3, 2, 2, 2, 45, 49, 3, 2, 2, 2, 46, 44, 3, 2, 2, 2, 47, 49, 5, 6, 4, 2, 48, 44, 3, 2, 2, 2, 48, 47, 3, 2, 2, 2, 49, 50, 3, 2, 2, 2, 50, 51, 7, 2, 2, 3, 51, 5, 3, 2, 2, 2, 52, 53, 8, 4, 1, 2, 53, 54, 7, 3, 2, 2, 54, 55, 5, 6, 4, 2, 55, 56, 7, 4, 2, 2, 56, 128, 3, 2, 2, 2, 57, 58, 7, 28, 2, 2, 58, 128, 5, 6, 4, 19, 59, 128, 5, 8, 5, 2, 60, 61, 7, 3, 2, 2, 61, 128, 7, 4, 2, 2, 62, 63, 7, 3, 2, 2, 63, 64, 5, 6, 4, 2, 64, 65, 7, 5, 2, 2, 65, 66, 7, 4, 2, 2, 66, 128, 3, 2, 2, 2, 67, 68, 7, 3, 2, 2, 68, 71, 5, 6, 4, 2, 69, 70, 7, 5, 2, 2, 70, 72, 5, 6, 4, 2, 71, 69, 3, 2, 2, 2, 72, 73, 3, 2, 2, 2, 73, 71, 3, 2, 2, 2, 73, 74, 3, 2, 2, 2, 74, 75, 3, 2, 2, 2, 75, 76, 7, 4, 2, 2, 76, 128, 3, 2, 2, 2, 77, 86, 7, 6, 2, 2, 78, 83, 5, 6, 4, 2, 79, 80, 7, 5, 2, 2, 80, 82, 5, 6, 4, 2, 81, 79, 3, 2, 2, 2, 82, 85, 3, 2, 2, 2, 83, 81, 3, 2, 2, 2, 83, 84, 3, 2, 2, 2, 84, 87, 3, 2, 2, 2, 85, 83, 3, 2, 2, 2, 86, 78, 3, 2, 2, 2, 86, 87, 3, 2, 2, 2, 87, 88, 3, 2, 2, 2, 88, 128, 7, 7, 2, 2, 89, 90, 7, 8, 2, 2, 90, 91, 7, 3, 2, 2, 91, 92, 5, 6, 4, 2, 92, 93, 7, 4, 2, 2, 93, 94, 5, 32, 17, 2, 94, 95, 7, 9, 2, 2, 95, 96, 5, 32, 17, 2, 96, 128, 3, 2, 2, 2, 97, 99, 7, 10, 2, 2, 98, 100, 7, 38, 2, 2, 99, 98, 3, 2, 2, 2, 99, 100, 3, 2, 2, 2, 100, 101, 3, 2, 2, 2, 101, 102, 5, 16, 9, 2, 102, 103, 7, 11, 2, 2, 103, 104, 5, 6, 4, 2, 104, 105, 7, 12, 2, 2, 105, 106, 5, 6, 4, 8, 106, 128, 3, 2, 2, 2, 107, 109, 7, 10, 2, 2, 108, 110, 7, 38, 2, 2, 109, 108, 3, 2, 2, 2, 109, 110, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 112, 5, 16, 9, 2, 112, 113, 7, 11, 2, 2, 113, 114, 7, 13, 2, 2, 114, 115, 5, 6, 4, 2, 115, 116, 7, 14, 2, 2, 116, 117, 7, 12, 2, 2, 117, 118, 5, 6, 4, 7, 118, 128, 3, 2, 2, 2, 119, 120, 5, 36, 19, 2, 120, 121, 7, 11, 2, 2, 121, 122, 5, 6, 4, 2, 122, 123, 7, 12, 2, 2, 123, 124, 5, 6, 4, 5, 124, 128, 3, 2, 2, 2, 125, 128, 5, 36, 19, 2, 126, 128, 5, 34, 18, 2, 127, 52, 3, 2, 2, 2, 127, 57, 3, 2, 2, 2, 127, 59, 3, 2, 2, 2, 127, 60, 3, 2, 2, 2, 127, 62, 3, 2, 2, 2, 127, 67, 3, 2, 2, 2, 127, 77, 3, 2, 2, 2, 127, 89, 3, 2, 2, 2, 127, 97, 3, 2, 2, 2, 127, 107, 3, 2, 2, 2, 127, 119, 3, 2, 2, 2, 127, 125, 3, 2, 2, 2, 127, 126, 3, 2, 2, 2, 128, 159, 3, 2, 2, 2, 129, 130, 12, 18, 2, 2, 130, 131, 9, 2, 2, 2, 131, 158, 5, 6, 4, 19, 132, 133, 12, 17, 2, 2, 133, 134, 9, 3, 2, 2, 134, 158, 5, 6, 4, 18, 135, 136, 12, 16, 2, 2, 136, 137, 9, 4, 2, 2, 137, 158, 5, 6, 4, 17, 138, 139, 12, 15, 2, 2, 139, 140, 9, 5, 2, 2, 140, 158, 5, 6, 4, 16, 141, 142, 12, 6, 2, 2, 142, 143, 7, 12, 2, 2, 143, 158, 5, 6, 4, 7, 144, 145, 12, 20, 2, 2, 145, 154, 7, 3, 2, 2, 146, 151, 5, 6, 4, 2, 147, 148, 7, 5, 2, 2, 148, 150, 5, 6, 4, 2, 149, 147, 3, 2, 2, 2, 150, 153, 3, 2, 2, 2, 151, 149, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 155, 3, 2, 2, 2, 153, 151, 3, 2, 2, 2, 154, 146, 3, 2, 2, 2, 154, 155, 3, 2, 2, 2, 155, 156, 3, 2, 2, 2, 156, 158, 7, 4, 2, 2, 157, 129, 3, 2, 2, 2, 157, 132, 3, 2, 2, 2, 157, 135, 3, 2, 2, 2, 157, 138, 3, 2, 2, 2, 157, 141, 3, 2, 2, 2, 157, 144, 3, 2, 2, 2, 158, 161, 3, 2, 2, 2, 159, 157, 3, 2, 2, 2, 159, 160, 3, 2, 2, 2, 160, 7, 3, 2, 2, 2, 161, 159, 3, 2, 2, 2, 162, 164, 7, 15, 2, 2, 163, 165, 5, 22, 12, 2, 164, 163, 3, 2, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 3, 2, 2, 2, 166, 167, 7, 3, 2, 2, 167, 168, 5, 12, 7, 2, 168, 171, 7, 4, 2, 2, 169, 170, 7, 16, 2, 2, 170, 172, 5, 24, 13, 2, 171, 169, 3, 2, 2, 2, 171, 172, 3, 2, 2, 2, 172, 173, 3, 2, 2, 2, 173, 174, 5, 32, 17, 2, 174, 9, 3, 2, 2, 2, 175, 176, 7, 17, 2, 2, 176, 178, 5, 36, 19, 2, 177, 179, 5, 22, 12, 2, 178, 177, 3, 2, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 3, 2, 2, 2, 180, 181, 7, 3, 2, 2, 181, 182, 5, 12, 7, 2, 182, 185, 7, 4, 2, 2, 183, 184, 7, 16, 2, 2, 184, 186, 5, 24, 13, 2, 185, 183, 3, 2, 2, 2, 185, 186, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 188, 5, 32, 17, 2, 188, 11, 3, 2, 2, 2, 189, 196, 5, 14, 8, 2, 190, 196, 5, 18, 10, 2, 191, 192, 5, 14, 8, 2, 192, 193, 7, 5, 2, 2, 193, 194, 5, 18, 10, 2, 194, 196, 3, 2, 2, 2, 195, 189, 3, 2, 2, 2, 195, 190, 3, 2, 2, 2, 195, 191, 3, 2, 2, 2, 196, 13, 3, 2, 2, 2, 197, 202, 5, 16, 9, 2, 198, 199, 7, 5, 2, 2, 199, 201, 5, 16, 9, 2, 200, 198, 3, 2, 2, 2, 201, 204, 3, 2, 2, 2, 202, 200, 3, 2, 2, 2, 202, 203, 3, 2, 2, 2, 203, 206, 3, 2, 2, 2, 204, 202, 3, 2, 2, 2, 205, 197, 3, 2, 2, 2, 205, 206, 3, 2, 2, 2, 206, 15, 3, 2, 2, 2, 207, 210, 5, 36, 19, 2, 208, 209, 7, 18, 2, 2, 209, 211, 5, 24, 13, 2, 210, 208, 3, 2, 2, 2, 210, 211, 3, 2, 2, 2, 211, 17, 3, 2, 2, 2, 212, 217, 5, 20, 11, 2, 213, 214, 7, 5, 2, 2, 214, 216, 5, 20, 11, 2, 215, 213, 3, 2, 2, 2, 216, 219, 3, 2, 2, 2, 217, 215, 3, 2, 2, 2, 217, 218, 3, 2, 2, 2, 218, 221, 3, 2, 2, 2, 219, 217, 3, 2, 2, 2, 220, 212, 3, 2, 2, 2, 220, 221, 3, 2, 2, 2, 221, 19, 3, 2, 2, 2, 222, 223, 7, 42, 2, 2, 223, 224, 7, 11, 2, 2, 224, 225, 5, 6, 4, 2, 225, 21, 3, 2, 2, 2, 226, 227, 7, 6, 2, 2, 227, 240, 7, 7, 2, 2, 228, 229, 7, 6, 2, 2, 229, 234, 5, 36, 19, 2, 230, 231, 7, 5, 2, 2, 231, 233, 5, 36, 19, 2, 232, 230, 3, 2, 2, 2, 233, 236, 3, 2, 2, 2, 234, 232, 3, 2, 2, 2, 234, 235, 3, 2, 2, 2, 235, 237, 3, 2, 2, 2, 236, 234, 3, 2, 2, 2, 237, 238, 7, 7, 2, 2, 238, 240, 3, 2, 2, 2, 239, 226, 3, 2, 2, 2, 239, 228, 3, 2, 2, 2, 240, 23, 3, 2, 2, 2, 241, 242, 7, 3, 2, 2, 242, 287, 7, 4, 2, 2, 243, 244, 7, 3, 2, 2, 244, 245, 5, 24, 13, 2, 245, 246, 7, 5, 2, 2, 246, 247, 7, 4, 2, 2, 247, 287, 3, 2, 2, 2, 248, 249, 7, 3, 2, 2, 249, 252, 5, 24, 13, 2, 250, 251, 7, 5, 2, 2, 251, 253, 5, 24, 13, 2, 252, 250, 3, 2, 2, 2, 253, 254, 3, 2, 2, 2, 254, 252, 3, 2, 2, 2, 254, 255, 3, 2, 2, 2, 255, 256, 3, 2, 2, 2, 256, 257, 7, 4, 2, 2, 257, 287, 3, 2, 2, 2, 258, 287, 5, 30, 16, 2, 259, 260, 7, 19, 2, 2, 260, 261, 7, 6, 2, 2, 261, 262, 5, 26, 14, 2, 262, 263, 7, 5, 2, 2, 263, 264, 5, 24, 13, 2, 264, 265, 7, 7, 2, 2, 265, 287, 3, 2, 2, 2, 266, 268, 7, 15, 2, 2, 267, 269, 5, 22, 12, 2, 268, 267, 3, 2, 2, 2, 268, 269, 3, 2, 2, 2, 269, 270, 3, 2, 2, 2, 270, 279, 7, 3, 2, 2, 271, 276, 5, 24, 13, 2, 272, 273, 7, 5, 2, 2, 273, 275, 5, 24, 13, 2, 274, 272, 3, 2, 2, 2, 275, 278, 3, 2, 2, 2, 276, 274, 3, 2, 2, 2, 276, 277, 3, 2, 2, 2, 277, 280, 3, 2, 2, 2, 278, 276, 3, 2, 2, 2, 279, 271, 3, 2, 2, 2, 279, 280, 3, 2, 2, 2, 280, 281, 3, 2, 2, 2, 281, 282, 7, 4, 2, 2, 282, 283, 7, 16, 2, 2, 283, 287, 5, 24, 13, 2, 284, 287, 7, 20, 2, 2, 285, 287, 7, 41, 2, 2, 286, 241, 3, 2, 2, 2, 286, 243, 3, 2, 2, 2, 286, 248, 3, 2, 2, 2, 286, 258, 3, 2, 2, 2, 286, 259, 3, 2, 2, 2, 286, 266, 3, 2, 2, 2, 286, 284, 3, 2, 2, 2, 286, 285, 3, 2, 2, 2, 287, 25, 3, 2, 2, 2, 288, 289, 7, 3, 2, 2, 289, 306, 7, 4, 2, 2, 290, 291, 7, 3, 2, 2, 291, 292, 5, 28, 15, 2, 292, 293, 7, 5, 2, 2, 293, 294, 7, 4, 2, 2, 294, 306, 3, 2, 2, 2, 295, 296, 7, 3, 2, 2, 296, 299, 5, 28, 15, 2, 297, 298, 7, 5, 2, 2, 298, 300, 5, 28, 15, 2, 299, 297, 3, 2, 2, 2, 300, 301, 3, 2, 2, 2, 301, 299, 3, 2, 2, 2, 301, 302, 3, 2, 2, 2, 302, 303, 3, 2, 2, 2, 303, 304, 7, 4, 2, 2, 304, 306, 3, 2, 2, 2, 305, 288, 3, 2, 2, 2, 305, 290, 3, 2, 2, 2, 305, 295, 3, 2, 2, 2, 306, 27, 3, 2, 2, 2, 307, 308, 7, 3, 2, 2, 308, 309, 5, 28, 15, 2, 309, 310, 7, 4, 2, 2, 310, 313, 3, 2, 2, 2, 311, 313, 7, 41, 2, 2, 312, 307, 3, 2, 2, 2, 312, 311, 3, 2, 2, 2, 313, 29, 3, 2, 2, 2, 314, 315, 7, 42, 2, 2, 315, 31, 3, 2, 2, 2, 316, 317, 7, 13, 2, 2, 317, 318, 5, 6, 4, 2, 318, 319, 7, 14, 2, 2, 319, 33, 3, 2, 2, 2, 320, 324, 7, 40, 2, 2, 321, 324, 7, 41, 2, 2, 322, 324, 7, 39, 2, 2, 323, 320, 3, 2, 2, 2, 323, 321, 3, 2, 2, 2, 323, 322, 3, 2, 2, 2, 324, 35, 3, 2, 2, 2, 325, 330, 5, 2, 2, 2, 326, 330, 7, 35, 2, 2, 327, 330, 7, 36, 2, 2, 328, 330, 7, 37, 2, 2, 329, 325, 3, 2, 2, 2, 329, 326, 3, 2, 2, 2, 329, 327, 3, 2, 2, 2, 329, 328, 3, 2, 2, 2, 330, 37, 3, 2, 2, 2, 36, 44, 48, 73, 83, 86, 99, 109, 127, 151, 154, 157, 159, 164, 171, 178, 185, 195, 202, 205, 210, 217, 220, 234, 239, 254, 268, 276, 279, 286, 301, 305, 312, 323, 329] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/Relay.tokens b/python/tvm/relay/grammar/py3/Relay.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py3/Relay.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py3/RelayLexer.interp b/python/tvm/relay/grammar/py3/RelayLexer.interp new file mode 100644 index 000000000000..092b3589ab70 --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.interp @@ -0,0 +1,140 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +T__0 +T__1 +T__2 +T__3 +T__4 +T__5 +T__6 +T__7 +T__8 +T__9 +T__10 +T__11 +T__12 +T__13 +T__14 +T__15 +T__16 +T__17 +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +EXP +CNAME +LETTER +DIGIT + +channel names: +DEFAULT_TOKEN_CHANNEL +HIDDEN + +mode names: +DEFAULT_MODE + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 42, 267, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 6, 21, 149, 10, 21, 13, 21, 14, 21, 150, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 22, 7, 22, 159, 10, 22, 12, 22, 14, 22, 162, 11, 22, 3, 22, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 23, 3, 23, 7, 23, 172, 10, 23, 12, 23, 14, 23, 175, 11, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 34, 3, 34, 3, 34, 3, 35, 3, 35, 3, 35, 3, 36, 3, 36, 3, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 228, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 234, 10, 39, 3, 39, 3, 39, 3, 39, 5, 39, 239, 10, 39, 3, 40, 6, 40, 242, 10, 40, 13, 40, 14, 40, 243, 3, 41, 3, 41, 5, 41, 248, 10, 41, 3, 41, 3, 41, 3, 42, 3, 42, 5, 42, 254, 10, 42, 3, 42, 3, 42, 3, 42, 7, 42, 259, 10, 42, 12, 42, 14, 42, 262, 11, 42, 3, 43, 3, 43, 3, 44, 3, 44, 4, 160, 173, 2, 45, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 2, 83, 42, 85, 2, 87, 2, 3, 2, 7, 5, 2, 11, 12, 15, 15, 34, 34, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 67, 92, 99, 124, 3, 2, 50, 59, 2, 275, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 3, 89, 3, 2, 2, 2, 5, 91, 3, 2, 2, 2, 7, 93, 3, 2, 2, 2, 9, 95, 3, 2, 2, 2, 11, 97, 3, 2, 2, 2, 13, 99, 3, 2, 2, 2, 15, 102, 3, 2, 2, 2, 17, 107, 3, 2, 2, 2, 19, 111, 3, 2, 2, 2, 21, 113, 3, 2, 2, 2, 23, 115, 3, 2, 2, 2, 25, 117, 3, 2, 2, 2, 27, 119, 3, 2, 2, 2, 29, 122, 3, 2, 2, 2, 31, 125, 3, 2, 2, 2, 33, 129, 3, 2, 2, 2, 35, 131, 3, 2, 2, 2, 37, 138, 3, 2, 2, 2, 39, 140, 3, 2, 2, 2, 41, 148, 3, 2, 2, 2, 43, 154, 3, 2, 2, 2, 45, 167, 3, 2, 2, 2, 47, 181, 3, 2, 2, 2, 49, 183, 3, 2, 2, 2, 51, 185, 3, 2, 2, 2, 53, 187, 3, 2, 2, 2, 55, 189, 3, 2, 2, 2, 57, 191, 3, 2, 2, 2, 59, 193, 3, 2, 2, 2, 61, 196, 3, 2, 2, 2, 63, 199, 3, 2, 2, 2, 65, 202, 3, 2, 2, 2, 67, 205, 3, 2, 2, 2, 69, 208, 3, 2, 2, 2, 71, 211, 3, 2, 2, 2, 73, 214, 3, 2, 2, 2, 75, 227, 3, 2, 2, 2, 77, 238, 3, 2, 2, 2, 79, 241, 3, 2, 2, 2, 81, 245, 3, 2, 2, 2, 83, 253, 3, 2, 2, 2, 85, 263, 3, 2, 2, 2, 87, 265, 3, 2, 2, 2, 89, 90, 7, 42, 2, 2, 90, 4, 3, 2, 2, 2, 91, 92, 7, 43, 2, 2, 92, 6, 3, 2, 2, 2, 93, 94, 7, 46, 2, 2, 94, 8, 3, 2, 2, 2, 95, 96, 7, 93, 2, 2, 96, 10, 3, 2, 2, 2, 97, 98, 7, 95, 2, 2, 98, 12, 3, 2, 2, 2, 99, 100, 7, 107, 2, 2, 100, 101, 7, 104, 2, 2, 101, 14, 3, 2, 2, 2, 102, 103, 7, 103, 2, 2, 103, 104, 7, 110, 2, 2, 104, 105, 7, 117, 2, 2, 105, 106, 7, 103, 2, 2, 106, 16, 3, 2, 2, 2, 107, 108, 7, 110, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 18, 3, 2, 2, 2, 111, 112, 7, 63, 2, 2, 112, 20, 3, 2, 2, 2, 113, 114, 7, 61, 2, 2, 114, 22, 3, 2, 2, 2, 115, 116, 7, 125, 2, 2, 116, 24, 3, 2, 2, 2, 117, 118, 7, 127, 2, 2, 118, 26, 3, 2, 2, 2, 119, 120, 7, 104, 2, 2, 120, 121, 7, 112, 2, 2, 121, 28, 3, 2, 2, 2, 122, 123, 7, 47, 2, 2, 123, 124, 7, 64, 2, 2, 124, 30, 3, 2, 2, 2, 125, 126, 7, 102, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 104, 2, 2, 128, 32, 3, 2, 2, 2, 129, 130, 7, 60, 2, 2, 130, 34, 3, 2, 2, 2, 131, 132, 7, 86, 2, 2, 132, 133, 7, 103, 2, 2, 133, 134, 7, 112, 2, 2, 134, 135, 7, 117, 2, 2, 135, 136, 7, 113, 2, 2, 136, 137, 7, 116, 2, 2, 137, 36, 3, 2, 2, 2, 138, 139, 7, 97, 2, 2, 139, 38, 3, 2, 2, 2, 140, 141, 7, 120, 2, 2, 141, 142, 7, 50, 2, 2, 142, 143, 7, 48, 2, 2, 143, 144, 7, 50, 2, 2, 144, 145, 7, 48, 2, 2, 145, 146, 7, 52, 2, 2, 146, 40, 3, 2, 2, 2, 147, 149, 9, 2, 2, 2, 148, 147, 3, 2, 2, 2, 149, 150, 3, 2, 2, 2, 150, 148, 3, 2, 2, 2, 150, 151, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 153, 8, 21, 2, 2, 153, 42, 3, 2, 2, 2, 154, 155, 7, 49, 2, 2, 155, 156, 7, 49, 2, 2, 156, 160, 3, 2, 2, 2, 157, 159, 11, 2, 2, 2, 158, 157, 3, 2, 2, 2, 159, 162, 3, 2, 2, 2, 160, 161, 3, 2, 2, 2, 160, 158, 3, 2, 2, 2, 161, 163, 3, 2, 2, 2, 162, 160, 3, 2, 2, 2, 163, 164, 7, 12, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 8, 22, 2, 2, 166, 44, 3, 2, 2, 2, 167, 168, 7, 49, 2, 2, 168, 169, 7, 44, 2, 2, 169, 173, 3, 2, 2, 2, 170, 172, 11, 2, 2, 2, 171, 170, 3, 2, 2, 2, 172, 175, 3, 2, 2, 2, 173, 174, 3, 2, 2, 2, 173, 171, 3, 2, 2, 2, 174, 176, 3, 2, 2, 2, 175, 173, 3, 2, 2, 2, 176, 177, 7, 44, 2, 2, 177, 178, 7, 49, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 8, 23, 2, 2, 180, 46, 3, 2, 2, 2, 181, 182, 7, 44, 2, 2, 182, 48, 3, 2, 2, 2, 183, 184, 7, 49, 2, 2, 184, 50, 3, 2, 2, 2, 185, 186, 7, 45, 2, 2, 186, 52, 3, 2, 2, 2, 187, 188, 7, 47, 2, 2, 188, 54, 3, 2, 2, 2, 189, 190, 7, 62, 2, 2, 190, 56, 3, 2, 2, 2, 191, 192, 7, 64, 2, 2, 192, 58, 3, 2, 2, 2, 193, 194, 7, 62, 2, 2, 194, 195, 7, 63, 2, 2, 195, 60, 3, 2, 2, 2, 196, 197, 7, 64, 2, 2, 197, 198, 7, 63, 2, 2, 198, 62, 3, 2, 2, 2, 199, 200, 7, 63, 2, 2, 200, 201, 7, 63, 2, 2, 201, 64, 3, 2, 2, 2, 202, 203, 7, 35, 2, 2, 203, 204, 7, 63, 2, 2, 204, 66, 3, 2, 2, 2, 205, 206, 7, 66, 2, 2, 206, 207, 5, 83, 42, 2, 207, 68, 3, 2, 2, 2, 208, 209, 7, 39, 2, 2, 209, 210, 5, 83, 42, 2, 210, 70, 3, 2, 2, 2, 211, 212, 7, 39, 2, 2, 212, 213, 5, 79, 40, 2, 213, 72, 3, 2, 2, 2, 214, 215, 7, 111, 2, 2, 215, 216, 7, 119, 2, 2, 216, 217, 7, 118, 2, 2, 217, 74, 3, 2, 2, 2, 218, 219, 7, 86, 2, 2, 219, 220, 7, 116, 2, 2, 220, 221, 7, 119, 2, 2, 221, 228, 7, 103, 2, 2, 222, 223, 7, 72, 2, 2, 223, 224, 7, 99, 2, 2, 224, 225, 7, 110, 2, 2, 225, 226, 7, 117, 2, 2, 226, 228, 7, 103, 2, 2, 227, 218, 3, 2, 2, 2, 227, 222, 3, 2, 2, 2, 228, 76, 3, 2, 2, 2, 229, 230, 5, 79, 40, 2, 230, 231, 7, 48, 2, 2, 231, 233, 5, 79, 40, 2, 232, 234, 5, 81, 41, 2, 233, 232, 3, 2, 2, 2, 233, 234, 3, 2, 2, 2, 234, 239, 3, 2, 2, 2, 235, 236, 5, 79, 40, 2, 236, 237, 5, 81, 41, 2, 237, 239, 3, 2, 2, 2, 238, 229, 3, 2, 2, 2, 238, 235, 3, 2, 2, 2, 239, 78, 3, 2, 2, 2, 240, 242, 5, 87, 44, 2, 241, 240, 3, 2, 2, 2, 242, 243, 3, 2, 2, 2, 243, 241, 3, 2, 2, 2, 243, 244, 3, 2, 2, 2, 244, 80, 3, 2, 2, 2, 245, 247, 9, 3, 2, 2, 246, 248, 9, 4, 2, 2, 247, 246, 3, 2, 2, 2, 247, 248, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 250, 5, 79, 40, 2, 250, 82, 3, 2, 2, 2, 251, 254, 7, 97, 2, 2, 252, 254, 5, 85, 43, 2, 253, 251, 3, 2, 2, 2, 253, 252, 3, 2, 2, 2, 254, 260, 3, 2, 2, 2, 255, 259, 7, 97, 2, 2, 256, 259, 5, 85, 43, 2, 257, 259, 5, 87, 44, 2, 258, 255, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 258, 257, 3, 2, 2, 2, 259, 262, 3, 2, 2, 2, 260, 258, 3, 2, 2, 2, 260, 261, 3, 2, 2, 2, 261, 84, 3, 2, 2, 2, 262, 260, 3, 2, 2, 2, 263, 264, 9, 5, 2, 2, 264, 86, 3, 2, 2, 2, 265, 266, 9, 6, 2, 2, 266, 88, 3, 2, 2, 2, 14, 2, 150, 160, 173, 227, 233, 238, 243, 247, 253, 258, 260, 3, 8, 2, 2] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py new file mode 100644 index 000000000000..fbf74bf1411b --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.py @@ -0,0 +1,203 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * +from io import StringIO +from typing.io import TextIO +import sys + + + +def serializedATN(): + with StringIO() as buf: + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2*") + buf.write("\u010b\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") + buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") + buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") + buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") + buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") + buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") + buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\3\2\3\2\3") + buf.write("\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7\3\b\3\b\3\b") + buf.write("\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13\3\f\3\f\3\r") + buf.write("\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20\3\20") + buf.write("\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22\3\22\3\23\3\23") + buf.write("\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3\25\6\25\u0095\n") + buf.write("\25\r\25\16\25\u0096\3\25\3\25\3\26\3\26\3\26\3\26\7\26") + buf.write("\u009f\n\26\f\26\16\26\u00a2\13\26\3\26\3\26\3\26\3\26") + buf.write("\3\27\3\27\3\27\3\27\7\27\u00ac\n\27\f\27\16\27\u00af") + buf.write("\13\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\31\3\31\3") + buf.write("\32\3\32\3\33\3\33\3\34\3\34\3\35\3\35\3\36\3\36\3\36") + buf.write("\3\37\3\37\3\37\3 \3 \3 \3!\3!\3!\3\"\3\"\3\"\3#\3#\3") + buf.write("#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&\3&\3&\3&\3&\3&\3&\5&\u00e4") + buf.write("\n&\3\'\3\'\3\'\3\'\5\'\u00ea\n\'\3\'\3\'\3\'\5\'\u00ef") + buf.write("\n\'\3(\6(\u00f2\n(\r(\16(\u00f3\3)\3)\5)\u00f8\n)\3)") + buf.write("\3)\3*\3*\5*\u00fe\n*\3*\3*\3*\7*\u0103\n*\f*\16*\u0106") + buf.write("\13*\3+\3+\3,\3,\4\u00a0\u00ad\2-\3\3\5\4\7\5\t\6\13\7") + buf.write("\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21") + buf.write("!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67") + buf.write("\359\36;\37= ?!A\"C#E$G%I&K\'M(O)Q\2S*U\2W\2\3\2\7\5\2") + buf.write("\13\f\17\17\"\"\4\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0113") + buf.write("\2\3\3\2\2\2\2\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13") + buf.write("\3\2\2\2\2\r\3\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3") + buf.write("\2\2\2\2\25\3\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2") + buf.write("\2\2\2\35\3\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2") + buf.write("%\3\2\2\2\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2") + buf.write("\2/\3\2\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67") + buf.write("\3\2\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2") + buf.write("A\3\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3\2\2\2") + buf.write("\2K\3\2\2\2\2M\3\2\2\2\2O\3\2\2\2\2S\3\2\2\2\3Y\3\2\2") + buf.write("\2\5[\3\2\2\2\7]\3\2\2\2\t_\3\2\2\2\13a\3\2\2\2\rc\3\2") + buf.write("\2\2\17f\3\2\2\2\21k\3\2\2\2\23o\3\2\2\2\25q\3\2\2\2\27") + buf.write("s\3\2\2\2\31u\3\2\2\2\33w\3\2\2\2\35z\3\2\2\2\37}\3\2") + buf.write("\2\2!\u0081\3\2\2\2#\u0083\3\2\2\2%\u008a\3\2\2\2\'\u008c") + buf.write("\3\2\2\2)\u0094\3\2\2\2+\u009a\3\2\2\2-\u00a7\3\2\2\2") + buf.write("/\u00b5\3\2\2\2\61\u00b7\3\2\2\2\63\u00b9\3\2\2\2\65\u00bb") + buf.write("\3\2\2\2\67\u00bd\3\2\2\29\u00bf\3\2\2\2;\u00c1\3\2\2") + buf.write("\2=\u00c4\3\2\2\2?\u00c7\3\2\2\2A\u00ca\3\2\2\2C\u00cd") + buf.write("\3\2\2\2E\u00d0\3\2\2\2G\u00d3\3\2\2\2I\u00d6\3\2\2\2") + buf.write("K\u00e3\3\2\2\2M\u00ee\3\2\2\2O\u00f1\3\2\2\2Q\u00f5\3") + buf.write("\2\2\2S\u00fd\3\2\2\2U\u0107\3\2\2\2W\u0109\3\2\2\2YZ") + buf.write("\7*\2\2Z\4\3\2\2\2[\\\7+\2\2\\\6\3\2\2\2]^\7.\2\2^\b\3") + buf.write("\2\2\2_`\7]\2\2`\n\3\2\2\2ab\7_\2\2b\f\3\2\2\2cd\7k\2") + buf.write("\2de\7h\2\2e\16\3\2\2\2fg\7g\2\2gh\7n\2\2hi\7u\2\2ij\7") + buf.write("g\2\2j\20\3\2\2\2kl\7n\2\2lm\7g\2\2mn\7v\2\2n\22\3\2\2") + buf.write("\2op\7?\2\2p\24\3\2\2\2qr\7=\2\2r\26\3\2\2\2st\7}\2\2") + buf.write("t\30\3\2\2\2uv\7\177\2\2v\32\3\2\2\2wx\7h\2\2xy\7p\2\2") + buf.write("y\34\3\2\2\2z{\7/\2\2{|\7@\2\2|\36\3\2\2\2}~\7f\2\2~\177") + buf.write("\7g\2\2\177\u0080\7h\2\2\u0080 \3\2\2\2\u0081\u0082\7") + buf.write("<\2\2\u0082\"\3\2\2\2\u0083\u0084\7V\2\2\u0084\u0085\7") + buf.write("g\2\2\u0085\u0086\7p\2\2\u0086\u0087\7u\2\2\u0087\u0088") + buf.write("\7q\2\2\u0088\u0089\7t\2\2\u0089$\3\2\2\2\u008a\u008b") + buf.write("\7a\2\2\u008b&\3\2\2\2\u008c\u008d\7x\2\2\u008d\u008e") + buf.write("\7\62\2\2\u008e\u008f\7\60\2\2\u008f\u0090\7\62\2\2\u0090") + buf.write("\u0091\7\60\2\2\u0091\u0092\7\64\2\2\u0092(\3\2\2\2\u0093") + buf.write("\u0095\t\2\2\2\u0094\u0093\3\2\2\2\u0095\u0096\3\2\2\2") + buf.write("\u0096\u0094\3\2\2\2\u0096\u0097\3\2\2\2\u0097\u0098\3") + buf.write("\2\2\2\u0098\u0099\b\25\2\2\u0099*\3\2\2\2\u009a\u009b") + buf.write("\7\61\2\2\u009b\u009c\7\61\2\2\u009c\u00a0\3\2\2\2\u009d") + buf.write("\u009f\13\2\2\2\u009e\u009d\3\2\2\2\u009f\u00a2\3\2\2") + buf.write("\2\u00a0\u00a1\3\2\2\2\u00a0\u009e\3\2\2\2\u00a1\u00a3") + buf.write("\3\2\2\2\u00a2\u00a0\3\2\2\2\u00a3\u00a4\7\f\2\2\u00a4") + buf.write("\u00a5\3\2\2\2\u00a5\u00a6\b\26\2\2\u00a6,\3\2\2\2\u00a7") + buf.write("\u00a8\7\61\2\2\u00a8\u00a9\7,\2\2\u00a9\u00ad\3\2\2\2") + buf.write("\u00aa\u00ac\13\2\2\2\u00ab\u00aa\3\2\2\2\u00ac\u00af") + buf.write("\3\2\2\2\u00ad\u00ae\3\2\2\2\u00ad\u00ab\3\2\2\2\u00ae") + buf.write("\u00b0\3\2\2\2\u00af\u00ad\3\2\2\2\u00b0\u00b1\7,\2\2") + buf.write("\u00b1\u00b2\7\61\2\2\u00b2\u00b3\3\2\2\2\u00b3\u00b4") + buf.write("\b\27\2\2\u00b4.\3\2\2\2\u00b5\u00b6\7,\2\2\u00b6\60\3") + buf.write("\2\2\2\u00b7\u00b8\7\61\2\2\u00b8\62\3\2\2\2\u00b9\u00ba") + buf.write("\7-\2\2\u00ba\64\3\2\2\2\u00bb\u00bc\7/\2\2\u00bc\66\3") + buf.write("\2\2\2\u00bd\u00be\7>\2\2\u00be8\3\2\2\2\u00bf\u00c0\7") + buf.write("@\2\2\u00c0:\3\2\2\2\u00c1\u00c2\7>\2\2\u00c2\u00c3\7") + buf.write("?\2\2\u00c3<\3\2\2\2\u00c4\u00c5\7@\2\2\u00c5\u00c6\7") + buf.write("?\2\2\u00c6>\3\2\2\2\u00c7\u00c8\7?\2\2\u00c8\u00c9\7") + buf.write("?\2\2\u00c9@\3\2\2\2\u00ca\u00cb\7#\2\2\u00cb\u00cc\7") + buf.write("?\2\2\u00ccB\3\2\2\2\u00cd\u00ce\7B\2\2\u00ce\u00cf\5") + buf.write("S*\2\u00cfD\3\2\2\2\u00d0\u00d1\7\'\2\2\u00d1\u00d2\5") + buf.write("S*\2\u00d2F\3\2\2\2\u00d3\u00d4\7\'\2\2\u00d4\u00d5\5") + buf.write("O(\2\u00d5H\3\2\2\2\u00d6\u00d7\7o\2\2\u00d7\u00d8\7w") + buf.write("\2\2\u00d8\u00d9\7v\2\2\u00d9J\3\2\2\2\u00da\u00db\7V") + buf.write("\2\2\u00db\u00dc\7t\2\2\u00dc\u00dd\7w\2\2\u00dd\u00e4") + buf.write("\7g\2\2\u00de\u00df\7H\2\2\u00df\u00e0\7c\2\2\u00e0\u00e1") + buf.write("\7n\2\2\u00e1\u00e2\7u\2\2\u00e2\u00e4\7g\2\2\u00e3\u00da") + buf.write("\3\2\2\2\u00e3\u00de\3\2\2\2\u00e4L\3\2\2\2\u00e5\u00e6") + buf.write("\5O(\2\u00e6\u00e7\7\60\2\2\u00e7\u00e9\5O(\2\u00e8\u00ea") + buf.write("\5Q)\2\u00e9\u00e8\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea\u00ef") + buf.write("\3\2\2\2\u00eb\u00ec\5O(\2\u00ec\u00ed\5Q)\2\u00ed\u00ef") + buf.write("\3\2\2\2\u00ee\u00e5\3\2\2\2\u00ee\u00eb\3\2\2\2\u00ef") + buf.write("N\3\2\2\2\u00f0\u00f2\5W,\2\u00f1\u00f0\3\2\2\2\u00f2") + buf.write("\u00f3\3\2\2\2\u00f3\u00f1\3\2\2\2\u00f3\u00f4\3\2\2\2") + buf.write("\u00f4P\3\2\2\2\u00f5\u00f7\t\3\2\2\u00f6\u00f8\t\4\2") + buf.write("\2\u00f7\u00f6\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") + buf.write("\3\2\2\2\u00f9\u00fa\5O(\2\u00faR\3\2\2\2\u00fb\u00fe") + buf.write("\7a\2\2\u00fc\u00fe\5U+\2\u00fd\u00fb\3\2\2\2\u00fd\u00fc") + buf.write("\3\2\2\2\u00fe\u0104\3\2\2\2\u00ff\u0103\7a\2\2\u0100") + buf.write("\u0103\5U+\2\u0101\u0103\5W,\2\u0102\u00ff\3\2\2\2\u0102") + buf.write("\u0100\3\2\2\2\u0102\u0101\3\2\2\2\u0103\u0106\3\2\2\2") + buf.write("\u0104\u0102\3\2\2\2\u0104\u0105\3\2\2\2\u0105T\3\2\2") + buf.write("\2\u0106\u0104\3\2\2\2\u0107\u0108\t\5\2\2\u0108V\3\2") + buf.write("\2\2\u0109\u010a\t\6\2\2\u010aX\3\2\2\2\16\2\u0096\u00a0") + buf.write("\u00ad\u00e3\u00e9\u00ee\u00f3\u00f7\u00fd\u0102\u0104") + buf.write("\3\b\2\2") + return buf.getvalue() + + +class RelayLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + SEMVER = 19 + WS = 20 + LINE_COMMENT = 21 + COMMENT = 22 + MUL = 23 + DIV = 24 + ADD = 25 + SUB = 26 + LT = 27 + GT = 28 + LE = 29 + GE = 30 + EQ = 31 + NE = 32 + GLOBAL_VAR = 33 + LOCAL_VAR = 34 + GRAPH_VAR = 35 + MUT = 36 + BOOL_LIT = 37 + FLOAT = 38 + NAT = 39 + CNAME = 40 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "'('", "')'", "','", "'['", "']'", "'if'", "'else'", "'let'", + "'='", "';'", "'{'", "'}'", "'fn'", "'->'", "'def'", "':'", + "'Tensor'", "'_'", "'v0.0.2'", "'*'", "'/'", "'+'", "'-'", "'<'", + "'>'", "'<='", "'>='", "'=='", "'!='", "'mut'" ] + + symbolicNames = [ "", + "SEMVER", "WS", "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", + "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", + "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", "NAT", "CNAME" ] + + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "SEMVER", "WS", "LINE_COMMENT", + "COMMENT", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", + "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", + "MUT", "BOOL_LIT", "FLOAT", "NAT", "EXP", "CNAME", "LETTER", + "DIGIT" ] + + grammarFileName = "Relay.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.7.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/python/tvm/relay/grammar/py3/RelayLexer.tokens b/python/tvm/relay/grammar/py3/RelayLexer.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py new file mode 100644 index 000000000000..ff5cffc36a9f --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -0,0 +1,2307 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from antlr4 import * +from io import StringIO +from typing.io import TextIO +import sys + + +def serializedATN(): + with StringIO() as buf: + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3*") + buf.write("\u014c\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") + buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") + buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") + buf.write("\3\2\3\2\3\3\3\3\7\3+\n\3\f\3\16\3.\13\3\3\3\5\3\61\n") + buf.write("\3\3\3\3\3\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\6\4H\n\4\r\4\16\4I\3") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\7\4R\n\4\f\4\16\4U\13\4\5\4W\n") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4d\n") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4n\n\4\3\4\3\4\3") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\5\4\u0080\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\7\4\u0096\n\4") + buf.write("\f\4\16\4\u0099\13\4\5\4\u009b\n\4\3\4\7\4\u009e\n\4\f") + buf.write("\4\16\4\u00a1\13\4\3\5\3\5\5\5\u00a5\n\5\3\5\3\5\3\5\3") + buf.write("\5\3\5\5\5\u00ac\n\5\3\5\3\5\3\6\3\6\3\6\5\6\u00b3\n\6") + buf.write("\3\6\3\6\3\6\3\6\3\6\5\6\u00ba\n\6\3\6\3\6\3\7\3\7\3\7") + buf.write("\3\7\3\7\3\7\5\7\u00c4\n\7\3\b\3\b\3\b\7\b\u00c9\n\b\f") + buf.write("\b\16\b\u00cc\13\b\5\b\u00ce\n\b\3\t\3\t\3\t\5\t\u00d3") + buf.write("\n\t\3\n\3\n\3\n\7\n\u00d8\n\n\f\n\16\n\u00db\13\n\5\n") + buf.write("\u00dd\n\n\3\13\3\13\3\13\3\13\3\f\3\f\3\f\3\f\3\f\3\f") + buf.write("\7\f\u00e9\n\f\f\f\16\f\u00ec\13\f\3\f\3\f\5\f\u00f0\n") + buf.write("\f\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\6\r\u00fd") + buf.write("\n\r\r\r\16\r\u00fe\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3") + buf.write("\r\3\r\3\r\3\r\5\r\u010d\n\r\3\r\3\r\3\r\3\r\7\r\u0113") + buf.write("\n\r\f\r\16\r\u0116\13\r\5\r\u0118\n\r\3\r\3\r\3\r\3\r") + buf.write("\3\r\5\r\u011f\n\r\3\16\3\16\3\16\3\16\3\16\3\16\3\16") + buf.write("\3\16\3\16\3\16\3\16\6\16\u012c\n\16\r\16\16\16\u012d") + buf.write("\3\16\3\16\5\16\u0132\n\16\3\17\3\17\3\17\3\17\3\17\5") + buf.write("\17\u0139\n\17\3\20\3\20\3\21\3\21\3\21\3\21\3\22\3\22") + buf.write("\3\22\5\22\u0144\n\22\3\23\3\23\3\23\3\23\5\23\u014a\n") + buf.write("\23\3\23\2\3\6\24\2\4\6\b\n\f\16\20\22\24\26\30\32\34") + buf.write("\36 \"$\2\6\3\2\31\32\3\2\33\34\3\2\35 \3\2!\"\2\u0175") + buf.write("\2&\3\2\2\2\4(\3\2\2\2\6\177\3\2\2\2\b\u00a2\3\2\2\2\n") + buf.write("\u00af\3\2\2\2\f\u00c3\3\2\2\2\16\u00cd\3\2\2\2\20\u00cf") + buf.write("\3\2\2\2\22\u00dc\3\2\2\2\24\u00de\3\2\2\2\26\u00ef\3") + buf.write("\2\2\2\30\u011e\3\2\2\2\32\u0131\3\2\2\2\34\u0138\3\2") + buf.write("\2\2\36\u013a\3\2\2\2 \u013c\3\2\2\2\"\u0143\3\2\2\2$") + buf.write("\u0149\3\2\2\2&\'\7*\2\2\'\3\3\2\2\2(\60\7\25\2\2)+\5") + buf.write("\n\6\2*)\3\2\2\2+.\3\2\2\2,*\3\2\2\2,-\3\2\2\2-\61\3\2") + buf.write("\2\2.,\3\2\2\2/\61\5\6\4\2\60,\3\2\2\2\60/\3\2\2\2\61") + buf.write("\62\3\2\2\2\62\63\7\2\2\3\63\5\3\2\2\2\64\65\b\4\1\2\65") + buf.write("\66\7\3\2\2\66\67\5\6\4\2\678\7\4\2\28\u0080\3\2\2\29") + buf.write(":\7\34\2\2:\u0080\5\6\4\23;\u0080\5\b\5\2<=\7\3\2\2=\u0080") + buf.write("\7\4\2\2>?\7\3\2\2?@\5\6\4\2@A\7\5\2\2AB\7\4\2\2B\u0080") + buf.write("\3\2\2\2CD\7\3\2\2DG\5\6\4\2EF\7\5\2\2FH\5\6\4\2GE\3\2") + buf.write("\2\2HI\3\2\2\2IG\3\2\2\2IJ\3\2\2\2JK\3\2\2\2KL\7\4\2\2") + buf.write("L\u0080\3\2\2\2MV\7\6\2\2NS\5\6\4\2OP\7\5\2\2PR\5\6\4") + buf.write("\2QO\3\2\2\2RU\3\2\2\2SQ\3\2\2\2ST\3\2\2\2TW\3\2\2\2U") + buf.write("S\3\2\2\2VN\3\2\2\2VW\3\2\2\2WX\3\2\2\2X\u0080\7\7\2\2") + buf.write("YZ\7\b\2\2Z[\7\3\2\2[\\\5\6\4\2\\]\7\4\2\2]^\5 \21\2^") + buf.write("_\7\t\2\2_`\5 \21\2`\u0080\3\2\2\2ac\7\n\2\2bd\7&\2\2") + buf.write("cb\3\2\2\2cd\3\2\2\2de\3\2\2\2ef\5\20\t\2fg\7\13\2\2g") + buf.write("h\5\6\4\2hi\7\f\2\2ij\5\6\4\bj\u0080\3\2\2\2km\7\n\2\2") + buf.write("ln\7&\2\2ml\3\2\2\2mn\3\2\2\2no\3\2\2\2op\5\20\t\2pq\7") + buf.write("\13\2\2qr\7\r\2\2rs\5\6\4\2st\7\16\2\2tu\7\f\2\2uv\5\6") + buf.write("\4\7v\u0080\3\2\2\2wx\5$\23\2xy\7\13\2\2yz\5\6\4\2z{\7") + buf.write("\f\2\2{|\5\6\4\5|\u0080\3\2\2\2}\u0080\5$\23\2~\u0080") + buf.write("\5\"\22\2\177\64\3\2\2\2\1779\3\2\2\2\177;\3\2\2\2\177") + buf.write("<\3\2\2\2\177>\3\2\2\2\177C\3\2\2\2\177M\3\2\2\2\177Y") + buf.write("\3\2\2\2\177a\3\2\2\2\177k\3\2\2\2\177w\3\2\2\2\177}\3") + buf.write("\2\2\2\177~\3\2\2\2\u0080\u009f\3\2\2\2\u0081\u0082\f") + buf.write("\22\2\2\u0082\u0083\t\2\2\2\u0083\u009e\5\6\4\23\u0084") + buf.write("\u0085\f\21\2\2\u0085\u0086\t\3\2\2\u0086\u009e\5\6\4") + buf.write("\22\u0087\u0088\f\20\2\2\u0088\u0089\t\4\2\2\u0089\u009e") + buf.write("\5\6\4\21\u008a\u008b\f\17\2\2\u008b\u008c\t\5\2\2\u008c") + buf.write("\u009e\5\6\4\20\u008d\u008e\f\6\2\2\u008e\u008f\7\f\2") + buf.write("\2\u008f\u009e\5\6\4\7\u0090\u0091\f\24\2\2\u0091\u009a") + buf.write("\7\3\2\2\u0092\u0097\5\6\4\2\u0093\u0094\7\5\2\2\u0094") + buf.write("\u0096\5\6\4\2\u0095\u0093\3\2\2\2\u0096\u0099\3\2\2\2") + buf.write("\u0097\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u009b\3") + buf.write("\2\2\2\u0099\u0097\3\2\2\2\u009a\u0092\3\2\2\2\u009a\u009b") + buf.write("\3\2\2\2\u009b\u009c\3\2\2\2\u009c\u009e\7\4\2\2\u009d") + buf.write("\u0081\3\2\2\2\u009d\u0084\3\2\2\2\u009d\u0087\3\2\2\2") + buf.write("\u009d\u008a\3\2\2\2\u009d\u008d\3\2\2\2\u009d\u0090\3") + buf.write("\2\2\2\u009e\u00a1\3\2\2\2\u009f\u009d\3\2\2\2\u009f\u00a0") + buf.write("\3\2\2\2\u00a0\7\3\2\2\2\u00a1\u009f\3\2\2\2\u00a2\u00a4") + buf.write("\7\17\2\2\u00a3\u00a5\5\26\f\2\u00a4\u00a3\3\2\2\2\u00a4") + buf.write("\u00a5\3\2\2\2\u00a5\u00a6\3\2\2\2\u00a6\u00a7\7\3\2\2") + buf.write("\u00a7\u00a8\5\f\7\2\u00a8\u00ab\7\4\2\2\u00a9\u00aa\7") + buf.write("\20\2\2\u00aa\u00ac\5\30\r\2\u00ab\u00a9\3\2\2\2\u00ab") + buf.write("\u00ac\3\2\2\2\u00ac\u00ad\3\2\2\2\u00ad\u00ae\5 \21\2") + buf.write("\u00ae\t\3\2\2\2\u00af\u00b0\7\21\2\2\u00b0\u00b2\5$\23") + buf.write("\2\u00b1\u00b3\5\26\f\2\u00b2\u00b1\3\2\2\2\u00b2\u00b3") + buf.write("\3\2\2\2\u00b3\u00b4\3\2\2\2\u00b4\u00b5\7\3\2\2\u00b5") + buf.write("\u00b6\5\f\7\2\u00b6\u00b9\7\4\2\2\u00b7\u00b8\7\20\2") + buf.write("\2\u00b8\u00ba\5\30\r\2\u00b9\u00b7\3\2\2\2\u00b9\u00ba") + buf.write("\3\2\2\2\u00ba\u00bb\3\2\2\2\u00bb\u00bc\5 \21\2\u00bc") + buf.write("\13\3\2\2\2\u00bd\u00c4\5\16\b\2\u00be\u00c4\5\22\n\2") + buf.write("\u00bf\u00c0\5\16\b\2\u00c0\u00c1\7\5\2\2\u00c1\u00c2") + buf.write("\5\22\n\2\u00c2\u00c4\3\2\2\2\u00c3\u00bd\3\2\2\2\u00c3") + buf.write("\u00be\3\2\2\2\u00c3\u00bf\3\2\2\2\u00c4\r\3\2\2\2\u00c5") + buf.write("\u00ca\5\20\t\2\u00c6\u00c7\7\5\2\2\u00c7\u00c9\5\20\t") + buf.write("\2\u00c8\u00c6\3\2\2\2\u00c9\u00cc\3\2\2\2\u00ca\u00c8") + buf.write("\3\2\2\2\u00ca\u00cb\3\2\2\2\u00cb\u00ce\3\2\2\2\u00cc") + buf.write("\u00ca\3\2\2\2\u00cd\u00c5\3\2\2\2\u00cd\u00ce\3\2\2\2") + buf.write("\u00ce\17\3\2\2\2\u00cf\u00d2\5$\23\2\u00d0\u00d1\7\22") + buf.write("\2\2\u00d1\u00d3\5\30\r\2\u00d2\u00d0\3\2\2\2\u00d2\u00d3") + buf.write("\3\2\2\2\u00d3\21\3\2\2\2\u00d4\u00d9\5\24\13\2\u00d5") + buf.write("\u00d6\7\5\2\2\u00d6\u00d8\5\24\13\2\u00d7\u00d5\3\2\2") + buf.write("\2\u00d8\u00db\3\2\2\2\u00d9\u00d7\3\2\2\2\u00d9\u00da") + buf.write("\3\2\2\2\u00da\u00dd\3\2\2\2\u00db\u00d9\3\2\2\2\u00dc") + buf.write("\u00d4\3\2\2\2\u00dc\u00dd\3\2\2\2\u00dd\23\3\2\2\2\u00de") + buf.write("\u00df\7*\2\2\u00df\u00e0\7\13\2\2\u00e0\u00e1\5\6\4\2") + buf.write("\u00e1\25\3\2\2\2\u00e2\u00e3\7\6\2\2\u00e3\u00f0\7\7") + buf.write("\2\2\u00e4\u00e5\7\6\2\2\u00e5\u00ea\5$\23\2\u00e6\u00e7") + buf.write("\7\5\2\2\u00e7\u00e9\5$\23\2\u00e8\u00e6\3\2\2\2\u00e9") + buf.write("\u00ec\3\2\2\2\u00ea\u00e8\3\2\2\2\u00ea\u00eb\3\2\2\2") + buf.write("\u00eb\u00ed\3\2\2\2\u00ec\u00ea\3\2\2\2\u00ed\u00ee\7") + buf.write("\7\2\2\u00ee\u00f0\3\2\2\2\u00ef\u00e2\3\2\2\2\u00ef\u00e4") + buf.write("\3\2\2\2\u00f0\27\3\2\2\2\u00f1\u00f2\7\3\2\2\u00f2\u011f") + buf.write("\7\4\2\2\u00f3\u00f4\7\3\2\2\u00f4\u00f5\5\30\r\2\u00f5") + buf.write("\u00f6\7\5\2\2\u00f6\u00f7\7\4\2\2\u00f7\u011f\3\2\2\2") + buf.write("\u00f8\u00f9\7\3\2\2\u00f9\u00fc\5\30\r\2\u00fa\u00fb") + buf.write("\7\5\2\2\u00fb\u00fd\5\30\r\2\u00fc\u00fa\3\2\2\2\u00fd") + buf.write("\u00fe\3\2\2\2\u00fe\u00fc\3\2\2\2\u00fe\u00ff\3\2\2\2") + buf.write("\u00ff\u0100\3\2\2\2\u0100\u0101\7\4\2\2\u0101\u011f\3") + buf.write("\2\2\2\u0102\u011f\5\36\20\2\u0103\u0104\7\23\2\2\u0104") + buf.write("\u0105\7\6\2\2\u0105\u0106\5\32\16\2\u0106\u0107\7\5\2") + buf.write("\2\u0107\u0108\5\30\r\2\u0108\u0109\7\7\2\2\u0109\u011f") + buf.write("\3\2\2\2\u010a\u010c\7\17\2\2\u010b\u010d\5\26\f\2\u010c") + buf.write("\u010b\3\2\2\2\u010c\u010d\3\2\2\2\u010d\u010e\3\2\2\2") + buf.write("\u010e\u0117\7\3\2\2\u010f\u0114\5\30\r\2\u0110\u0111") + buf.write("\7\5\2\2\u0111\u0113\5\30\r\2\u0112\u0110\3\2\2\2\u0113") + buf.write("\u0116\3\2\2\2\u0114\u0112\3\2\2\2\u0114\u0115\3\2\2\2") + buf.write("\u0115\u0118\3\2\2\2\u0116\u0114\3\2\2\2\u0117\u010f\3") + buf.write("\2\2\2\u0117\u0118\3\2\2\2\u0118\u0119\3\2\2\2\u0119\u011a") + buf.write("\7\4\2\2\u011a\u011b\7\20\2\2\u011b\u011f\5\30\r\2\u011c") + buf.write("\u011f\7\24\2\2\u011d\u011f\7)\2\2\u011e\u00f1\3\2\2\2") + buf.write("\u011e\u00f3\3\2\2\2\u011e\u00f8\3\2\2\2\u011e\u0102\3") + buf.write("\2\2\2\u011e\u0103\3\2\2\2\u011e\u010a\3\2\2\2\u011e\u011c") + buf.write("\3\2\2\2\u011e\u011d\3\2\2\2\u011f\31\3\2\2\2\u0120\u0121") + buf.write("\7\3\2\2\u0121\u0132\7\4\2\2\u0122\u0123\7\3\2\2\u0123") + buf.write("\u0124\5\34\17\2\u0124\u0125\7\5\2\2\u0125\u0126\7\4\2") + buf.write("\2\u0126\u0132\3\2\2\2\u0127\u0128\7\3\2\2\u0128\u012b") + buf.write("\5\34\17\2\u0129\u012a\7\5\2\2\u012a\u012c\5\34\17\2\u012b") + buf.write("\u0129\3\2\2\2\u012c\u012d\3\2\2\2\u012d\u012b\3\2\2\2") + buf.write("\u012d\u012e\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0130\7") + buf.write("\4\2\2\u0130\u0132\3\2\2\2\u0131\u0120\3\2\2\2\u0131\u0122") + buf.write("\3\2\2\2\u0131\u0127\3\2\2\2\u0132\33\3\2\2\2\u0133\u0134") + buf.write("\7\3\2\2\u0134\u0135\5\34\17\2\u0135\u0136\7\4\2\2\u0136") + buf.write("\u0139\3\2\2\2\u0137\u0139\7)\2\2\u0138\u0133\3\2\2\2") + buf.write("\u0138\u0137\3\2\2\2\u0139\35\3\2\2\2\u013a\u013b\7*\2") + buf.write("\2\u013b\37\3\2\2\2\u013c\u013d\7\r\2\2\u013d\u013e\5") + buf.write("\6\4\2\u013e\u013f\7\16\2\2\u013f!\3\2\2\2\u0140\u0144") + buf.write("\7(\2\2\u0141\u0144\7)\2\2\u0142\u0144\7\'\2\2\u0143\u0140") + buf.write("\3\2\2\2\u0143\u0141\3\2\2\2\u0143\u0142\3\2\2\2\u0144") + buf.write("#\3\2\2\2\u0145\u014a\5\2\2\2\u0146\u014a\7#\2\2\u0147") + buf.write("\u014a\7$\2\2\u0148\u014a\7%\2\2\u0149\u0145\3\2\2\2\u0149") + buf.write("\u0146\3\2\2\2\u0149\u0147\3\2\2\2\u0149\u0148\3\2\2\2") + buf.write("\u014a%\3\2\2\2$,\60ISVcm\177\u0097\u009a\u009d\u009f") + buf.write("\u00a4\u00ab\u00b2\u00b9\u00c3\u00ca\u00cd\u00d2\u00d9") + buf.write("\u00dc\u00ea\u00ef\u00fe\u010c\u0114\u0117\u011e\u012d") + buf.write("\u0131\u0138\u0143\u0149") + return buf.getvalue() + + +class RelayParser ( Parser ): + + grammarFileName = "Relay.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "'('", "')'", "','", "'['", "']'", "'if'", + "'else'", "'let'", "'='", "';'", "'{'", "'}'", "'fn'", + "'->'", "'def'", "':'", "'Tensor'", "'_'", "'v0.0.2'", + "", "", "", "'*'", "'/'", + "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", + "'!='", "", "", "", "'mut'" ] + + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "SEMVER", "WS", + "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", "SUB", + "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", + "LOCAL_VAR", "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", + "NAT", "CNAME" ] + + RULE_opIdent = 0 + RULE_prog = 1 + RULE_expr = 2 + RULE_func = 3 + RULE_defn = 4 + RULE_argList = 5 + RULE_varList = 6 + RULE_var = 7 + RULE_attrList = 8 + RULE_attr = 9 + RULE_typeParamSeq = 10 + RULE_type_ = 11 + RULE_shapeSeq = 12 + RULE_shape = 13 + RULE_typeIdent = 14 + RULE_body = 15 + RULE_scalar = 16 + RULE_ident = 17 + + ruleNames = [ "opIdent", "prog", "expr", "func", "defn", "argList", + "varList", "var", "attrList", "attr", "typeParamSeq", + "type_", "shapeSeq", "shape", "typeIdent", "body", "scalar", + "ident" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + SEMVER=19 + WS=20 + LINE_COMMENT=21 + COMMENT=22 + MUL=23 + DIV=24 + ADD=25 + SUB=26 + LT=27 + GT=28 + LE=29 + GE=30 + EQ=31 + NE=32 + GLOBAL_VAR=33 + LOCAL_VAR=34 + GRAPH_VAR=35 + MUT=36 + BOOL_LIT=37 + FLOAT=38 + NAT=39 + CNAME=40 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.7.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class OpIdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_opIdent + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitOpIdent" ): + return visitor.visitOpIdent(self) + else: + return visitor.visitChildren(self) + + + + + def opIdent(self): + + localctx = RelayParser.OpIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_opIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 36 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ProgContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def SEMVER(self): + return self.getToken(RelayParser.SEMVER, 0) + + def EOF(self): + return self.getToken(RelayParser.EOF, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def defn(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.DefnContext) + else: + return self.getTypedRuleContext(RelayParser.DefnContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_prog + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitProg" ): + return visitor.visitProg(self) + else: + return visitor.visitChildren(self) + + + + + def prog(self): + + localctx = RelayParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 38 + self.match(RelayParser.SEMVER) + self.state = 46 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.EOF, RelayParser.T__14]: + self.state = 42 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__14: + self.state = 39 + self.defn() + self.state = 44 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [RelayParser.T__0, RelayParser.T__3, RelayParser.T__5, RelayParser.T__7, RelayParser.T__12, RelayParser.SUB, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.BOOL_LIT, RelayParser.FLOAT, RelayParser.NAT, RelayParser.CNAME]: + self.state = 45 + self.expr(0) + pass + else: + raise NoViableAltException(self) + + self.state = 48 + self.match(RelayParser.EOF) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_expr + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + class IdentExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdentExpr" ): + return visitor.visitIdentExpr(self) + else: + return visitor.visitChildren(self) + + + class CallContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitCall" ): + return visitor.visitCall(self) + else: + return visitor.visitChildren(self) + + + class NegContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNeg" ): + return visitor.visitNeg(self) + else: + return visitor.visitChildren(self) + + + class TupleContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTuple" ): + return visitor.visitTuple(self) + else: + return visitor.visitChildren(self) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitParens" ): + return visitor.visitParens(self) + else: + return visitor.visitChildren(self) + + + class FuncExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def func(self): + return self.getTypedRuleContext(RelayParser.FuncContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFuncExpr" ): + return visitor.visitFuncExpr(self) + else: + return visitor.visitChildren(self) + + + class ScalarExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def scalar(self): + return self.getTypedRuleContext(RelayParser.ScalarContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarExpr" ): + return visitor.visitScalarExpr(self) + else: + return visitor.visitChildren(self) + + + class LetContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def var(self): + return self.getTypedRuleContext(RelayParser.VarContext,0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUT(self): + return self.getToken(RelayParser.MUT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitLet" ): + return visitor.visitLet(self) + else: + return visitor.visitChildren(self) + + + class TensorContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTensor" ): + return visitor.visitTensor(self) + else: + return visitor.visitChildren(self) + + + class IfElseContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + def body(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.BodyContext) + else: + return self.getTypedRuleContext(RelayParser.BodyContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIfElse" ): + return visitor.visitIfElse(self) + else: + return visitor.visitChildren(self) + + + class GraphContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitGraph" ): + return visitor.visitGraph(self) + else: + return visitor.visitChildren(self) + + + class BinOpContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.op = None # Token + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUL(self): + return self.getToken(RelayParser.MUL, 0) + def DIV(self): + return self.getToken(RelayParser.DIV, 0) + def ADD(self): + return self.getToken(RelayParser.ADD, 0) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def LT(self): + return self.getToken(RelayParser.LT, 0) + def GT(self): + return self.getToken(RelayParser.GT, 0) + def LE(self): + return self.getToken(RelayParser.LE, 0) + def GE(self): + return self.getToken(RelayParser.GE, 0) + def EQ(self): + return self.getToken(RelayParser.EQ, 0) + def NE(self): + return self.getToken(RelayParser.NE, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitBinOp" ): + return visitor.visitBinOp(self) + else: + return visitor.visitChildren(self) + + + + def expr(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = RelayParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 4 + self.enterRecursionRule(localctx, 4, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 125 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + if la_ == 1: + localctx = RelayParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 51 + self.match(RelayParser.T__0) + self.state = 52 + self.expr(0) + self.state = 53 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.NegContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 55 + self.match(RelayParser.SUB) + self.state = 56 + self.expr(17) + pass + + elif la_ == 3: + localctx = RelayParser.FuncExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 57 + self.func() + pass + + elif la_ == 4: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 58 + self.match(RelayParser.T__0) + self.state = 59 + self.match(RelayParser.T__1) + pass + + elif la_ == 5: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 60 + self.match(RelayParser.T__0) + self.state = 61 + self.expr(0) + self.state = 62 + self.match(RelayParser.T__2) + self.state = 63 + self.match(RelayParser.T__1) + pass + + elif la_ == 6: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 65 + self.match(RelayParser.T__0) + self.state = 66 + self.expr(0) + self.state = 69 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 67 + self.match(RelayParser.T__2) + self.state = 68 + self.expr(0) + self.state = 71 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 73 + self.match(RelayParser.T__1) + pass + + elif la_ == 7: + localctx = RelayParser.TensorContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 75 + self.match(RelayParser.T__3) + self.state = 84 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 76 + self.expr(0) + self.state = 81 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 77 + self.match(RelayParser.T__2) + self.state = 78 + self.expr(0) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 86 + self.match(RelayParser.T__4) + pass + + elif la_ == 8: + localctx = RelayParser.IfElseContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 87 + self.match(RelayParser.T__5) + self.state = 88 + self.match(RelayParser.T__0) + self.state = 89 + self.expr(0) + self.state = 90 + self.match(RelayParser.T__1) + self.state = 91 + self.body() + self.state = 92 + self.match(RelayParser.T__6) + self.state = 93 + self.body() + pass + + elif la_ == 9: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 95 + self.match(RelayParser.T__7) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 96 + self.match(RelayParser.MUT) + + + self.state = 99 + self.var() + self.state = 100 + self.match(RelayParser.T__8) + self.state = 101 + self.expr(0) + self.state = 102 + self.match(RelayParser.T__9) + self.state = 103 + self.expr(6) + pass + + elif la_ == 10: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 105 + self.match(RelayParser.T__7) + self.state = 107 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 106 + self.match(RelayParser.MUT) + + + self.state = 109 + self.var() + self.state = 110 + self.match(RelayParser.T__8) + self.state = 111 + self.match(RelayParser.T__10) + self.state = 112 + self.expr(0) + self.state = 113 + self.match(RelayParser.T__11) + self.state = 114 + self.match(RelayParser.T__9) + self.state = 115 + self.expr(5) + pass + + elif la_ == 11: + localctx = RelayParser.GraphContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 117 + self.ident() + self.state = 118 + self.match(RelayParser.T__8) + self.state = 119 + self.expr(0) + self.state = 120 + self.match(RelayParser.T__9) + self.state = 121 + self.expr(3) + pass + + elif la_ == 12: + localctx = RelayParser.IdentExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 123 + self.ident() + pass + + elif la_ == 13: + localctx = RelayParser.ScalarExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 124 + self.scalar() + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 157 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 155 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,10,self._ctx) + if la_ == 1: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 127 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 128 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.MUL or _la==RelayParser.DIV): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 129 + self.expr(17) + pass + + elif la_ == 2: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 130 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 131 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.ADD or _la==RelayParser.SUB): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 132 + self.expr(16) + pass + + elif la_ == 3: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 133 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 134 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.expr(15) + pass + + elif la_ == 4: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 136 + if not self.precpred(self._ctx, 13): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 13)") + self.state = 137 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.EQ or _la==RelayParser.NE): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 138 + self.expr(14) + pass + + elif la_ == 5: + localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 139 + if not self.precpred(self._ctx, 4): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 140 + self.match(RelayParser.T__9) + self.state = 141 + self.expr(5) + pass + + elif la_ == 6: + localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 142 + if not self.precpred(self._ctx, 18): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") + self.state = 143 + self.match(RelayParser.T__0) + self.state = 152 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 144 + self.expr(0) + self.state = 149 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 145 + self.match(RelayParser.T__2) + self.state = 146 + self.expr(0) + self.state = 151 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 154 + self.match(RelayParser.T__1) + pass + + + self.state = 159 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class FuncContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_func + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFunc" ): + return visitor.visitFunc(self) + else: + return visitor.visitChildren(self) + + + + + def func(self): + + localctx = RelayParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_func) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 160 + self.match(RelayParser.T__12) + self.state = 162 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 161 + self.typeParamSeq() + + + self.state = 164 + self.match(RelayParser.T__0) + self.state = 165 + self.argList() + self.state = 166 + self.match(RelayParser.T__1) + self.state = 169 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 167 + self.match(RelayParser.T__13) + self.state = 168 + self.type_() + + + self.state = 171 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DefnContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_defn + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitDefn" ): + return visitor.visitDefn(self) + else: + return visitor.visitChildren(self) + + + + + def defn(self): + + localctx = RelayParser.DefnContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_defn) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 173 + self.match(RelayParser.T__14) + self.state = 174 + self.ident() + self.state = 176 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 175 + self.typeParamSeq() + + + self.state = 178 + self.match(RelayParser.T__0) + self.state = 179 + self.argList() + self.state = 180 + self.match(RelayParser.T__1) + self.state = 183 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 181 + self.match(RelayParser.T__13) + self.state = 182 + self.type_() + + + self.state = 185 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varList(self): + return self.getTypedRuleContext(RelayParser.VarListContext,0) + + + def attrList(self): + return self.getTypedRuleContext(RelayParser.AttrListContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_argList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitArgList" ): + return visitor.visitArgList(self) + else: + return visitor.visitChildren(self) + + + + + def argList(self): + + localctx = RelayParser.ArgListContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_argList) + try: + self.state = 193 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,16,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 187 + self.varList() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 188 + self.attrList() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 189 + self.varList() + self.state = 190 + self.match(RelayParser.T__2) + self.state = 191 + self.attrList() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def var(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.VarContext) + else: + return self.getTypedRuleContext(RelayParser.VarContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_varList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitVarList" ): + return visitor.visitVarList(self) + else: + return visitor.visitChildren(self) + + + + + def varList(self): + + localctx = RelayParser.VarListContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_varList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.CNAME))) != 0): + self.state = 195 + self.var() + self.state = 200 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 196 + self.match(RelayParser.T__2) + self.state = 197 + self.var() + self.state = 202 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_var + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitVar" ): + return visitor.visitVar(self) + else: + return visitor.visitChildren(self) + + + + + def var(self): + + localctx = RelayParser.VarContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_var) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 205 + self.ident() + self.state = 208 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__15: + self.state = 206 + self.match(RelayParser.T__15) + self.state = 207 + self.type_() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def attr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.AttrContext) + else: + return self.getTypedRuleContext(RelayParser.AttrContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_attrList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAttrList" ): + return visitor.visitAttrList(self) + else: + return visitor.visitChildren(self) + + + + + def attrList(self): + + localctx = RelayParser.AttrListContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_attrList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.CNAME: + self.state = 210 + self.attr() + self.state = 215 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 211 + self.match(RelayParser.T__2) + self.state = 212 + self.attr() + self.state = 217 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_attr + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAttr" ): + return visitor.visitAttr(self) + else: + return visitor.visitChildren(self) + + + + + def attr(self): + + localctx = RelayParser.AttrContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_attr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 220 + self.match(RelayParser.CNAME) + self.state = 221 + self.match(RelayParser.T__8) + self.state = 222 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeParamSeqContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.IdentContext) + else: + return self.getTypedRuleContext(RelayParser.IdentContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_typeParamSeq + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeParamSeq" ): + return visitor.visitTypeParamSeq(self) + else: + return visitor.visitChildren(self) + + + + + def typeParamSeq(self): + + localctx = RelayParser.TypeParamSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_typeParamSeq) + self._la = 0 # Token type + try: + self.state = 237 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,23,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 224 + self.match(RelayParser.T__3) + self.state = 225 + self.match(RelayParser.T__4) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 226 + self.match(RelayParser.T__3) + self.state = 227 + self.ident() + self.state = 232 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 228 + self.match(RelayParser.T__2) + self.state = 229 + self.ident() + self.state = 234 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 235 + self.match(RelayParser.T__4) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Type_Context(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_type_ + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class IntTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIntType" ): + return visitor.visitIntType(self) + else: + return visitor.visitChildren(self) + + + class TupleTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def type_(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTupleType" ): + return visitor.visitTupleType(self) + else: + return visitor.visitChildren(self) + + + class TypeIdentTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def typeIdent(self): + return self.getTypedRuleContext(RelayParser.TypeIdentContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeIdentType" ): + return visitor.visitTypeIdentType(self) + else: + return visitor.visitChildren(self) + + + class IncompleteTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIncompleteType" ): + return visitor.visitIncompleteType(self) + else: + return visitor.visitChildren(self) + + + class TensorTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def shapeSeq(self): + return self.getTypedRuleContext(RelayParser.ShapeSeqContext,0) + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTensorType" ): + return visitor.visitTensorType(self) + else: + return visitor.visitChildren(self) + + + class FuncTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def type_(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFuncType" ): + return visitor.visitFuncType(self) + else: + return visitor.visitChildren(self) + + + + def type_(self): + + localctx = RelayParser.Type_Context(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_type_) + self._la = 0 # Token type + try: + self.state = 284 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 239 + self.match(RelayParser.T__0) + self.state = 240 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 241 + self.match(RelayParser.T__0) + self.state = 242 + self.type_() + self.state = 243 + self.match(RelayParser.T__2) + self.state = 244 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.match(RelayParser.T__0) + self.state = 247 + self.type_() + self.state = 250 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 248 + self.match(RelayParser.T__2) + self.state = 249 + self.type_() + self.state = 252 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 254 + self.match(RelayParser.T__1) + pass + + elif la_ == 4: + localctx = RelayParser.TypeIdentTypeContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 256 + self.typeIdent() + pass + + elif la_ == 5: + localctx = RelayParser.TensorTypeContext(self, localctx) + self.enterOuterAlt(localctx, 5) + self.state = 257 + self.match(RelayParser.T__16) + self.state = 258 + self.match(RelayParser.T__3) + self.state = 259 + self.shapeSeq() + self.state = 260 + self.match(RelayParser.T__2) + self.state = 261 + self.type_() + self.state = 262 + self.match(RelayParser.T__4) + pass + + elif la_ == 6: + localctx = RelayParser.FuncTypeContext(self, localctx) + self.enterOuterAlt(localctx, 6) + self.state = 264 + self.match(RelayParser.T__12) + self.state = 266 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 265 + self.typeParamSeq() + + + self.state = 268 + self.match(RelayParser.T__0) + self.state = 277 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__12) | (1 << RelayParser.T__16) | (1 << RelayParser.T__17) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 269 + self.type_() + self.state = 274 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 270 + self.match(RelayParser.T__2) + self.state = 271 + self.type_() + self.state = 276 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 279 + self.match(RelayParser.T__1) + self.state = 280 + self.match(RelayParser.T__13) + self.state = 281 + self.type_() + pass + + elif la_ == 7: + localctx = RelayParser.IncompleteTypeContext(self, localctx) + self.enterOuterAlt(localctx, 7) + self.state = 282 + self.match(RelayParser.T__17) + pass + + elif la_ == 8: + localctx = RelayParser.IntTypeContext(self, localctx) + self.enterOuterAlt(localctx, 8) + self.state = 283 + self.match(RelayParser.NAT) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeSeqContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def shape(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ShapeContext) + else: + return self.getTypedRuleContext(RelayParser.ShapeContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_shapeSeq + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShapeSeq" ): + return visitor.visitShapeSeq(self) + else: + return visitor.visitChildren(self) + + + + + def shapeSeq(self): + + localctx = RelayParser.ShapeSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_shapeSeq) + self._la = 0 # Token type + try: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 286 + self.match(RelayParser.T__0) + self.state = 287 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 288 + self.match(RelayParser.T__0) + self.state = 289 + self.shape() + self.state = 290 + self.match(RelayParser.T__2) + self.state = 291 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 293 + self.match(RelayParser.T__0) + self.state = 294 + self.shape() + self.state = 297 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 295 + self.match(RelayParser.T__2) + self.state = 296 + self.shape() + self.state = 299 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 301 + self.match(RelayParser.T__1) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_shape + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class ParensShapeContext(ShapeContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext + super().__init__(parser) + self.copyFrom(ctx) + + def shape(self): + return self.getTypedRuleContext(RelayParser.ShapeContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitParensShape" ): + return visitor.visitParensShape(self) + else: + return visitor.visitChildren(self) + + + class IntShapeContext(ShapeContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIntShape" ): + return visitor.visitIntShape(self) + else: + return visitor.visitChildren(self) + + + + def shape(self): + + localctx = RelayParser.ShapeContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_shape) + try: + self.state = 310 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.T__0]: + localctx = RelayParser.ParensShapeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 305 + self.match(RelayParser.T__0) + self.state = 306 + self.shape() + self.state = 307 + self.match(RelayParser.T__1) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.IntShapeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 309 + self.match(RelayParser.NAT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeIdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_typeIdent + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeIdent" ): + return visitor.visitTypeIdent(self) + else: + return visitor.visitChildren(self) + + + + + def typeIdent(self): + + localctx = RelayParser.TypeIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_typeIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 312 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BodyContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_body + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitBody" ): + return visitor.visitBody(self) + else: + return visitor.visitChildren(self) + + + + + def body(self): + + localctx = RelayParser.BodyContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_body) + try: + self.enterOuterAlt(localctx, 1) + self.state = 314 + self.match(RelayParser.T__10) + self.state = 315 + self.expr(0) + self.state = 316 + self.match(RelayParser.T__11) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ScalarContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_scalar + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class ScalarFloatContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(RelayParser.FLOAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarFloat" ): + return visitor.visitScalarFloat(self) + else: + return visitor.visitChildren(self) + + + class ScalarBoolContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def BOOL_LIT(self): + return self.getToken(RelayParser.BOOL_LIT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarBool" ): + return visitor.visitScalarBool(self) + else: + return visitor.visitChildren(self) + + + class ScalarIntContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarInt" ): + return visitor.visitScalarInt(self) + else: + return visitor.visitChildren(self) + + + + def scalar(self): + + localctx = RelayParser.ScalarContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_scalar) + try: + self.state = 321 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.FLOAT]: + localctx = RelayParser.ScalarFloatContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.match(RelayParser.FLOAT) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.ScalarIntContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 319 + self.match(RelayParser.NAT) + pass + elif token in [RelayParser.BOOL_LIT]: + localctx = RelayParser.ScalarBoolContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 320 + self.match(RelayParser.BOOL_LIT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def opIdent(self): + return self.getTypedRuleContext(RelayParser.OpIdentContext,0) + + + def GLOBAL_VAR(self): + return self.getToken(RelayParser.GLOBAL_VAR, 0) + + def LOCAL_VAR(self): + return self.getToken(RelayParser.LOCAL_VAR, 0) + + def GRAPH_VAR(self): + return self.getToken(RelayParser.GRAPH_VAR, 0) + + def getRuleIndex(self): + return RelayParser.RULE_ident + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdent" ): + return visitor.visitIdent(self) + else: + return visitor.visitChildren(self) + + + + + def ident(self): + + localctx = RelayParser.IdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_ident) + try: + self.state = 327 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.CNAME]: + self.enterOuterAlt(localctx, 1) + self.state = 323 + self.opIdent() + pass + elif token in [RelayParser.GLOBAL_VAR]: + self.enterOuterAlt(localctx, 2) + self.state = 324 + self.match(RelayParser.GLOBAL_VAR) + pass + elif token in [RelayParser.LOCAL_VAR]: + self.enterOuterAlt(localctx, 3) + self.state = 325 + self.match(RelayParser.LOCAL_VAR) + pass + elif token in [RelayParser.GRAPH_VAR]: + self.enterOuterAlt(localctx, 4) + self.state = 326 + self.match(RelayParser.GRAPH_VAR) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[2] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 13) + + + if predIndex == 4: + return self.precpred(self._ctx, 4) + + + if predIndex == 5: + return self.precpred(self._ctx, 18) + + + + + diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py new file mode 100644 index 000000000000..64308dca1a3a --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -0,0 +1,198 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * +if __name__ is not None and "." in __name__: + from .RelayParser import RelayParser +else: + from RelayParser import RelayParser + +# This class defines a complete generic visitor for a parse tree produced by RelayParser. + +class RelayVisitor(ParseTreeVisitor): + + # Visit a parse tree produced by RelayParser#opIdent. + def visitOpIdent(self, ctx:RelayParser.OpIdentContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#prog. + def visitProg(self, ctx:RelayParser.ProgContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#identExpr. + def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#call. + def visitCall(self, ctx:RelayParser.CallContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#neg. + def visitNeg(self, ctx:RelayParser.NegContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tuple. + def visitTuple(self, ctx:RelayParser.TupleContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parens. + def visitParens(self, ctx:RelayParser.ParensContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcExpr. + def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarExpr. + def visitScalarExpr(self, ctx:RelayParser.ScalarExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#let. + def visitLet(self, ctx:RelayParser.LetContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensor. + def visitTensor(self, ctx:RelayParser.TensorContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ifElse. + def visitIfElse(self, ctx:RelayParser.IfElseContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#graph. + def visitGraph(self, ctx:RelayParser.GraphContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#binOp. + def visitBinOp(self, ctx:RelayParser.BinOpContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#func. + def visitFunc(self, ctx:RelayParser.FuncContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#defn. + def visitDefn(self, ctx:RelayParser.DefnContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#argList. + def visitArgList(self, ctx:RelayParser.ArgListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#varList. + def visitVarList(self, ctx:RelayParser.VarListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#var. + def visitVar(self, ctx:RelayParser.VarContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attrList. + def visitAttrList(self, ctx:RelayParser.AttrListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attr. + def visitAttr(self, ctx:RelayParser.AttrContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeParamSeq. + def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tupleType. + def visitTupleType(self, ctx:RelayParser.TupleTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdentType. + def visitTypeIdentType(self, ctx:RelayParser.TypeIdentTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensorType. + def visitTensorType(self, ctx:RelayParser.TensorTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcType. + def visitFuncType(self, ctx:RelayParser.FuncTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#incompleteType. + def visitIncompleteType(self, ctx:RelayParser.IncompleteTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intType. + def visitIntType(self, ctx:RelayParser.IntTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#shapeSeq. + def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parensShape. + def visitParensShape(self, ctx:RelayParser.ParensShapeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intShape. + def visitIntShape(self, ctx:RelayParser.IntShapeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdent. + def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#body. + def visitBody(self, ctx:RelayParser.BodyContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarFloat. + def visitScalarFloat(self, ctx:RelayParser.ScalarFloatContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarInt. + def visitScalarInt(self, ctx:RelayParser.ScalarIntContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarBool. + def visitScalarBool(self, ctx:RelayParser.ScalarBoolContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ident. + def visitIdent(self, ctx:RelayParser.IdentContext): + return self.visitChildren(ctx) + + + +del RelayParser \ No newline at end of file diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 2d27b7b53f89..9218cae3de66 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -18,19 +18,6 @@ from __future__ import absolute_import from .. import register_func -def enabled(): - """Checks whether the parser is enabled, this allows users to - optionally support building the parser. - - We use this check before importing the parser. - """ - try: - # pylint: disable=unused-variable - from tvm.relay import _parser - return True - # pylint: disable=broad-except - except Exception: - return False @register_func("relay.fromtext") def fromtext(data, source_name=None): diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 92647e5b14b4..c801e490d4cf 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -15,12 +15,16 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name -"""Adds certain standard global functions and ADT definitions to the module.""" +"""A prelude containing useful global functions and ADT definitions.""" +import os from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard +from .parser import fromtext + +__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) class Prelude: """Contains standard definitions.""" @@ -451,35 +455,6 @@ def define_tree_size(self): Match(t, [rose_case]), scalar_type('int32'), [a]) - def define_id(self): - """Defines a function that return its argument. - - Signature: fn(x : a) -> a - """ - self.id = GlobalVar("id") - a = TypeVar("a") - x = Var("x", a) - self.mod[self.id] = Function([x], x, a, [a]) - - - def define_compose(self): - """Defines a function that composes two function. - - Signature: fn(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c - """ - self.compose = GlobalVar("compose") - a = TypeVar("a") - b = TypeVar("b") - c = TypeVar("c") - f = Var("f", FuncType([b], c)) - g = Var("g", FuncType([a], b)) - x = Var("x") - self.mod[self.compose] = Function([f, g], - Function([x], f(g(x))), - FuncType([a], c), - [a, b, c]) - - def define_iterate(self): """Defines a function that take a number n and a function f; returns a closure that takes an argument and applies f @@ -500,9 +475,23 @@ def define_iterate(self): FuncType([a], a), [a]) + def load_prelude(self): + """ + Parses the portions of the Prelude written in Relay's text format and adds + them to the module. + """ + prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly") + with open(prelude_file) as prelude: + prelude = fromtext(prelude.read()) + self.mod.update(prelude) + self.id = self.mod["id"] + self.compose = self.mod["compose"] + def __init__(self, mod): self.mod = mod + self.load_prelude() + self.define_list_adt() self.define_list_hd() self.define_list_tl() @@ -530,6 +519,4 @@ def __init__(self, mod): self.define_tree_map() self.define_tree_size() - self.define_id() - self.define_compose() self.define_iterate() diff --git a/python/tvm/relay/prelude.rly b/python/tvm/relay/prelude.rly new file mode 100644 index 000000000000..35c794a6d479 --- /dev/null +++ b/python/tvm/relay/prelude.rly @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +v0.0.2 + +def @id[a](%x: a) -> a { + %x +} + +def @compose[a, b, c](%f: fn(b) -> c, %g: fn(a) -> b) { + fn (%x: a) -> c { + %f(%g(%x)) + } +} diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 69322f5efaaa..d33cede97508 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -41,6 +41,8 @@ "pxi", "pyd", "pyx", + # relay text format + "rly", # configurations "mk", "in", @@ -66,11 +68,15 @@ "sbt", "properties", "v", + # generated parser + "interp", + "tokens" } # List of file names allowed ALLOW_FILE_NAME = { ".gitignore", + ".gitattributes", "README", "Makefile", "Doxyfile", @@ -155,7 +161,7 @@ def main(): report += "\nFound %d files that are now allowed\n" % len(error_list) report += ("We do not check in binary files into the repo.\n" "If necessary, please discuss with committers and" - "modify tests/scripts/check_file_type.py to enable the file you need.\n") + "modify tests/lint/check_file_type.py to enable the file you need.\n") sys.stderr.write(report) sys.stderr.flush() sys.exit(-1) diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 72faa1112c94..934e65c2cad2 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -23,6 +23,8 @@ .*\.csv .*\.mk .*\.log +.*\.interp +.*\.tokens # Generated modules .*\.egg-info @@ -35,10 +37,16 @@ _build .*~ \#..*\# +# Relay parser +RelayLexer.py +RelayParser.py +RelayVisitor.py + # Specific files package-list MANIFEST .gitignore +.gitattributes .gitmodules .clang-format .bash_history diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index d3c30be2a953..79b010ba0cb0 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -16,19 +16,14 @@ # under the License. import tvm from tvm import relay -from tvm.relay.parser import enabled from tvm.relay.ir_pass import alpha_equal -from nose import SkipTest from nose.tools import nottest, raises from numpy import isclose from typing import Union from functools import wraps -if enabled(): - raises_parse_error = raises(tvm._ffi.base.TVMError) -else: - raises_parse_error = lambda x: x +raises_parse_error = raises(tvm._ffi.base.TVMError) -SEMVER = "v0.0.1" +SEMVER = "v0.0.2" BINARY_OPS = { "*": relay.multiply, @@ -65,9 +60,12 @@ "float16x4", } +def parse_text(code): + return relay.fromtext(SEMVER + "\n" + code) + def parses_as(code, expr): # type: (str, relay.Expr) -> bool - return alpha_equal(relay.fromtext(SEMVER + "\n" + code), expr) + return alpha_equal(parse_text(code), expr) def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) @@ -83,17 +81,7 @@ def get_scalar(x): UNIT = relay.Tuple([]) -# decorator to determine if parser is enabled -def if_parser_enabled(func): - # https://stackoverflow.com/q/7727678 - @wraps(func) - def wrapper(): - if not enabled(): - raise SkipTest("ANTLR is not installed!") - func() - return wrapper - -@if_parser_enabled + def test_comments(): assert parses_as( """ @@ -113,46 +101,46 @@ def test_comments(): UNIT ) -@if_parser_enabled + def test_int_literal(): - assert isinstance(relay.fromtext(SEMVER+"1"), relay.Constant) - assert isinstance(relay.fromtext(SEMVER+"1").data, tvm.ndarray.NDArray) + assert isinstance(parse_text("1"), relay.Constant) + assert isinstance(parse_text("1").data, tvm.ndarray.NDArray) + + assert get_scalar(parse_text("1")) == 1 + assert get_scalar(parse_text("10")) == 10 + assert get_scalar(parse_text("0")) == 0 + assert get_scalar(parse_text("-100")) == -100 + assert get_scalar(parse_text("-05")) == -5 - assert get_scalar(relay.fromtext(SEMVER+"1")) == 1 - assert get_scalar(relay.fromtext(SEMVER+"10")) == 10 - assert get_scalar(relay.fromtext(SEMVER+"0")) == 0 - assert get_scalar(relay.fromtext(SEMVER+"-100")) == -100 - assert get_scalar(relay.fromtext(SEMVER+"-05")) == -5 -@if_parser_enabled def test_float_literal(): - assert get_scalar(relay.fromtext(SEMVER+"1.0")) == 1.0 - assert isclose(get_scalar(relay.fromtext(SEMVER+"1.56667")), 1.56667) - assert get_scalar(relay.fromtext(SEMVER+"0.0")) == 0.0 - assert get_scalar(relay.fromtext(SEMVER+"-10.0")) == -10.0 + assert get_scalar(parse_text("1.0")) == 1.0 + assert isclose(get_scalar(parse_text("1.56667")), 1.56667) + assert get_scalar(parse_text("0.0")) == 0.0 + assert get_scalar(parse_text("-10.0")) == -10.0 # scientific notation - assert isclose(get_scalar(relay.fromtext(SEMVER+"1e-1")), 1e-1) - assert get_scalar(relay.fromtext(SEMVER+"1e+1")) == 1e+1 - assert isclose(get_scalar(relay.fromtext(SEMVER+"1E-1")), 1E-1) - assert get_scalar(relay.fromtext(SEMVER+"1E+1")) == 1E+1 - assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0e-1")), 1.0e-1) - assert get_scalar(relay.fromtext(SEMVER+"1.0e+1")) == 1.0e+1 - assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0E-1")), 1.0E-1) - assert get_scalar(relay.fromtext(SEMVER+"1.0E+1")) == 1.0E+1 - -@if_parser_enabled + assert isclose(get_scalar(parse_text("1e-1")), 1e-1) + assert get_scalar(parse_text("1e+1")) == 1e+1 + assert isclose(get_scalar(parse_text("1E-1")), 1E-1) + assert get_scalar(parse_text("1E+1")) == 1E+1 + assert isclose(get_scalar(parse_text("1.0e-1")), 1.0e-1) + assert get_scalar(parse_text("1.0e+1")) == 1.0e+1 + assert isclose(get_scalar(parse_text("1.0E-1")), 1.0E-1) + assert get_scalar(parse_text("1.0E+1")) == 1.0E+1 + + def test_bool_literal(): - assert get_scalar(relay.fromtext(SEMVER+"True")) == True - assert get_scalar(relay.fromtext(SEMVER+"False")) == False + assert get_scalar(parse_text("True")) == True + assert get_scalar(parse_text("False")) == False + -@if_parser_enabled def test_negative(): - assert isinstance(relay.fromtext(SEMVER+"let %x = 1; -%x").body, relay.Call) - assert get_scalar(relay.fromtext(SEMVER+"--10")) == 10 - assert get_scalar(relay.fromtext(SEMVER+"---10")) == -10 + assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) + assert get_scalar(parse_text("--10")) == 10 + assert get_scalar(parse_text("---10")) == -10 + -@if_parser_enabled def test_bin_op(): for bin_op in BINARY_OPS.keys(): assert parses_as( @@ -160,18 +148,18 @@ def test_bin_op(): BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) -@if_parser_enabled + def test_parens(): - assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"(1 * 1) + 1")) - assert not alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"1 * (1 + 1)")) + assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1")) + assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)")) + -@if_parser_enabled def test_op_assoc(): - assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1 < 1 == 1"), relay.fromtext(SEMVER+"(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(relay.fromtext(SEMVER+"1 == 1 < 1 + 1 * 1"), relay.fromtext(SEMVER+"1 == (1 < (1 + (1 * 1)))")) + assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) + @nottest -@if_parser_enabled def test_vars(): # temp vars won't work b/c they start with a digit # # temp var @@ -180,21 +168,21 @@ def test_vars(): # assert temp_var.name == "1" # var - var = relay.fromtext(SEMVER+"let %foo = (); %foo") + var = parse_text("let %foo = (); %foo") assert isinstance(var.body, relay.Var) assert var.body.name_hint == "foo" # global var - global_var = relay.fromtext(SEMVER+"@foo") + global_var = parse_text("@foo") assert isinstance(global_var, relay.GlobalVar) assert global_var.name_hint == "foo" # operator id - op = relay.fromtext(SEMVER+"foo") + op = parse_text("foo") assert isinstance(op, relay.Op) assert op.name == "foo" -@if_parser_enabled + def test_let(): assert parses_as( "let %x = 1; ()", @@ -222,7 +210,7 @@ def test_let(): ) ) -@if_parser_enabled + def test_seq(): assert parses_as( "(); ()", @@ -241,7 +229,7 @@ def test_seq(): ) ) -@if_parser_enabled + def test_graph(): assert parses_as( "%0 = (); %1 = 1; (%0, %0, %1)", @@ -253,22 +241,22 @@ def test_graph(): relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) ) + @raises_parse_error -@if_parser_enabled def test_graph_wrong_order(): - relay.fromtext(SEMVER+"%1 = (); %1") + parse_text("%1 = (); %1") + @raises_parse_error -@if_parser_enabled def test_let_global_var(): - relay.fromtext(SEMVER+"let @x = 1; ()") + parse_text("let @x = 1; ()") + @raises_parse_error -@if_parser_enabled def test_let_op(): - relay.fromtext(SEMVER+"let x = 1; ()") + parse_text("let x = 1; ()") + -@if_parser_enabled def test_tuple(): assert parses_as("()", relay.Tuple([])) @@ -278,7 +266,7 @@ def test_tuple(): assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) -@if_parser_enabled + def test_func(): # 0 args assert parses_as( @@ -330,8 +318,8 @@ def test_func(): relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5))) ) + # TODO(@jmp): Crashes if %x isn't annnotated. -@if_parser_enabled def test_defn(): id_defn = relay.fromtext( SEMVER+ @@ -342,7 +330,7 @@ def @id(%x: int32) -> int32 { """) assert isinstance(id_defn, relay.Module) -@if_parser_enabled + def test_recursive_call(): id_defn = relay.fromtext( SEMVER+ @@ -353,7 +341,7 @@ def @id(%x: int32) -> int32 { """) assert isinstance(id_defn, relay.Module) -@if_parser_enabled + def test_ifelse(): assert parses_as( """ @@ -370,8 +358,8 @@ def test_ifelse(): ) ) + @raises_parse_error -@if_parser_enabled def test_ifelse_scope(): relay.fromtext( SEMVER+ @@ -385,7 +373,7 @@ def test_ifelse_scope(): """ ) -@if_parser_enabled + def test_call(): # select right function to call: simple ident case id_func = relay.Var("id") @@ -509,7 +497,7 @@ def test_call(): # Types -@if_parser_enabled + def test_incomplete_type(): assert parses_as( "let %_ : _ = (); ()", @@ -520,17 +508,17 @@ def test_incomplete_type(): ) ) -@if_parser_enabled + def test_builtin_types(): for builtin_type in TYPES: - relay.fromtext(SEMVER+"let %_ : {} = (); ()".format(builtin_type)) + parse_text("let %_ : {} = (); ()".format(builtin_type)) + @nottest -@if_parser_enabled def test_call_type(): assert False -@if_parser_enabled + def test_tensor_type(): assert parses_as( "let %_ : Tensor[(), float32] = (); ()", @@ -559,7 +547,7 @@ def test_tensor_type(): ) ) -@if_parser_enabled + def test_function_type(): assert parses_as( """ @@ -594,7 +582,7 @@ def test_function_type(): ) ) -@if_parser_enabled + def test_tuple_type(): assert parses_as( """ From f1d3d6fb61fefe55b96b0bca22a94edba3cb1d5c Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 10 Jun 2019 21:26:15 -0700 Subject: [PATCH 113/176] Add LOGISTIC operator to relay tflite frontend (#3313) --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3a13473202a3..eb9e742ff85f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -67,6 +67,7 @@ def __init__(self, model, subgraph, exp_tab): 'MUL': self.convert_mul, 'FULLY_CONNECTED': self.convert_fully_connected, 'PAD': self.convert_pad, + 'LOGISTIC': self.convert_logistic, } def check_unsupported_ops(self): @@ -218,6 +219,23 @@ def convert_reshape(self, op): return out + def convert_logistic(self, op): + """Convert TFLite LOGISTIC""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + out = _op.sigmoid(in_expr) + return out + def convert_softmax(self, op): """Convert TFLite softmax""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7da2b851bb3f..5c2e3afb5a0d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -423,6 +423,22 @@ def test_forward_pad(): np.array([[1, 1], [2, 2]], dtype=np.int32)]) +####################################################################### +# Logistic +# -------- + +def _test_logistic(data): + """ One iteration of LOGISTIC """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = math_ops.sigmoid(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_logistic(): + """ LOGISTIC """ + _test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + + ####################################################################### # Softmax # ------- @@ -563,6 +579,7 @@ def test_forward_inception_v4_net(): # NN test_forward_convolution() + test_forward_logistic() test_forward_pooling() test_forward_softmax() test_forward_fully_connected() From 5aa821258d08eb5892fcdb3e803c16b17034bd74 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Tue, 11 Jun 2019 18:54:04 +0100 Subject: [PATCH 114/176] [CI] Clarify RAT exclude patterns. (#3328) --- tests/lint/rat-excludes | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 934e65c2cad2..2e7a23a1ea4f 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -1,6 +1,8 @@ # subdirectories -3rdparty/* -.github/* +3rdparty +.github +jvm +tutorials # Binary or data files .*\.css @@ -28,7 +30,7 @@ # Generated modules .*\.egg-info -.*gen_modules/* +.*gen_modules .*doxygen core.cpp build @@ -36,6 +38,7 @@ _static _build .*~ \#..*\# +\.#.* # Relay parser RelayLexer.py From 278d81a9dab842a871cdf423b8db7668ef7670c1 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Jun 2019 10:55:24 -0700 Subject: [PATCH 115/176] [RELAY] Pass infra cleanup (#3336) --- include/tvm/relay/transform.h | 5 +- python/tvm/relay/transform.py | 314 ++++++++++++------------ src/relay/pass/pass_manager.cc | 8 +- tests/python/relay/test_pass_manager.py | 7 + 4 files changed, 169 insertions(+), 165 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 793bc981ea61..f579f1c7ba91 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode { v->Visit("required", &required); } - TVM_DLL static PassInfo make(int opt_level, std::string name, + TVM_DLL static PassInfo make(int opt_level, + std::string name, tvm::Array required); static constexpr const char* _type_key = "relay.PassInfo"; @@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference(); * type information filled in, as well as it's checked type field * populated with the result type. * - * \return The pass. + * \return The pass. */ TVM_DLL Pass InferType(); diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 38079b010e7d..b76c2361605c 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -14,13 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return -# pylint: disable=unidiomatic-typecheck # pylint: disable=invalid-name """ -This file contains the pass manager for Relay which exposes different -granularity of interfaces for users to implement and use passes more -conveniently. +Relay pass transformation infrastructure. """ import types @@ -39,19 +35,19 @@ class PassInfo(RelayNode): Parameters ---------- - name : str - The pass name. - opt_level : int The optimization level of this pass. + name : str + The pass name. + required : List[str] The list of passes that are required by a certain pass. """ - def __init__(self, name, opt_level, required=None): - self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level, - required) + def __init__(self, opt_level, name, required=None): + self.__init_handle_by_constructor__( + _transform.PassInfo, opt_level, name, required) @register_relay_node @@ -194,7 +190,7 @@ class ModulePass(Pass): `module_pass`, because the design of the `module_pass` API is flexible enough to handle the creation of a module pass in different manners. In addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass and Sequential as well. + The same rule applies to FunctionPass as well. """ @@ -250,153 +246,6 @@ def __init__(self, passes, opt_level, name, required) -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. - - Examples - -------- - The following code creates a module level pass and adds an abs function to - the module. - - .. code-block:: python - - @relay.transform.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, transform.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_func): - """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _transform.CreateModulePass( - pass_func, opt_level, name if name else pass_func.__name__, - required) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - -def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created function pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the function pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. - - Examples - -------- - The following code creates a function level pass that performs constant - folding. - - .. code-block:: python - - @relay.transform.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) - - function_pass = transform - assert isinstance(function_pass, transform.FunctionPass) - assert function_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = function_pass(m) - # Now constant folding should have been applied to every function in - # the provided module m. And the updated module will be returned. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the funtion pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_function_pass(pass_func): - """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _transform.CreateFunctionPass( - pass_func, opt_level, name if name else pass_func.__name__, - required) - - if pass_func: - return create_function_pass(pass_func) - return create_function_pass - - def InferType(): """Infer the type of an expr. @@ -593,3 +442,150 @@ def PartialEvaluate(): The registered pass that performs partial evaluation on an expression. """ return _transform.PartialEvaluate() + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a module pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created module level pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + The callable that will create a module pass is returned when + pass_func is not passed in. Otherwise, a ModulePass object will be + directly created. + + Examples + -------- + The following code creates a module level pass and adds an abs function to + the module. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = relay.Module({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_func): + """Internal function that creates a module pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + fname = name if name else pass_func.__name__ + info = PassInfo(opt_level, fname, required) + return _transform.MakeModulePass(pass_func, info) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass + + +def function_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a function pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + The callable that will create a function pass is returned when + pass_func is not passed in. Otherwise, a FunctionPass object will be + created. + + Examples + -------- + The following code creates a function level pass that performs constant + folding. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=2) + def transform(func, ctx): + return ir_pass.fold_constant(func) + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the funtion pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_func): + """Internal function that creates a function pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + fname = name if name else pass_func.__name__ + info = PassInfo(opt_level, fname, required) + return _transform.MakeFunctionPass(pass_func, info) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 782bb6a5980f..500bdce742a0 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_API("relay._transform.CreateModulePass") -.set_body_typed(CreateModulePass); +TVM_REGISTER_API("relay._transform.MakeModulePass") +.set_body_typed(ModulePassNode::make); TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_API("relay._transform.CreateFunctionPass") -.set_body_typed(CreateFunctionPass); +TVM_REGISTER_API("relay._transform.MakeFunctionPass") +.set_body_typed(FunctionPassNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 7fdef3fa8b9c..7505aa9ab981 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -259,6 +259,12 @@ def test_pass_run(): test_pass_run() +def test_pass_info(): + info = relay.transform.PassInfo(opt_level=1, name="xyz") + assert info.opt_level == 1 + assert info.name == "xyz" + + def test_sequential_pass(): shape = (10, ) dtype = 'float32' @@ -449,3 +455,4 @@ def expected(): test_function_pass() test_sequential_pass() test_sequential_with_scoping() + test_pass_info() From 9f94927dccaebb7b871144b38ab459ef4db51d8a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Jun 2019 10:55:37 -0700 Subject: [PATCH 116/176] [CI] separate out legacy as a stage (#3337) --- Jenkinsfile | 11 ++++++ tests/scripts/task_python_frontend.sh | 29 +------------- tests/scripts/task_python_legacy.sh | 55 +++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 28 deletions(-) create mode 100755 tests/scripts/task_python_legacy.sh diff --git a/Jenkinsfile b/Jenkinsfile index ea468c24ed53..bdbb3ecb6427 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -281,6 +281,17 @@ stage('Integration Test') { } } }, + 'legacy: GPU': { + node('GPU') { + ws('workspace/tvm/legacy-python-gpu') { + init_git() + unpack_lib('gpu', tvm_multilib) + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_legacy.sh" + } + } + } + }, 'docs: GPU': { node('GPU') { ws('workspace/tvm/docs-python-gpu') { diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 609b00149bad..9985d4ab7821 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -29,33 +29,6 @@ make cython3 echo "Running relay TFLite frontend test..." python3 -m nose -v tests/python/frontend/tflite -echo "Running nnvm unittest..." -python3 -m nose -v nnvm/tests/python/unittest - -echo "Running nnvm compiler test..." -python3 -m nose -v nnvm/tests/python/compiler - -echo "Running nnvm ONNX frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/onnx - -echo "Running nnvm MXNet frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/mxnet - -echo "Running nnvm Keras frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/keras - -echo "Running nnvm Tensorflow frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/tensorflow - -echo "Running nnvm CoreML frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/coreml - -echo "Running nnvm Caffe2 frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/caffe2 - -echo "Running nnvm DarkNet frontend test..." -python3 -m nose -v nnvm/tests/python/frontend/darknet || exit -1 - echo "Running relay MXNet frontend test..." python3 -m nose -v tests/python/frontend/mxnet @@ -78,4 +51,4 @@ echo "Running relay caffe2 frontend test..." python3 -m nose -v tests/python/frontend/caffe2 echo "Running relay DarkNet frontend test..." -python3 -m nose -v tests/python/frontend/darknet || exit -1 +python3 -m nose -v tests/python/frontend/darknet diff --git a/tests/scripts/task_python_legacy.sh b/tests/scripts/task_python_legacy.sh new file mode 100755 index 000000000000..df1615bb8550 --- /dev/null +++ b/tests/scripts/task_python_legacy.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test cases for legacy code, will be deprecated in the future. +set -e +set -u + +export PYTHONPATH=nnvm/python:python:topi/python +export OMP_NUM_THREADS=1 + +# Rebuild cython +make cython3 + +echo "Running nnvm unittest..." +python3 -m nose -v nnvm/tests/python/unittest + + +echo "Running nnvm compiler test..." +python3 -m nose -v nnvm/tests/python/compiler + +echo "Running nnvm ONNX frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/onnx + +echo "Running nnvm MXNet frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/mxnet + +echo "Running nnvm DarkNet frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/darknet + +echo "Running nnvm Keras frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/keras + +echo "Running nnvm Tensorflow frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/tensorflow + +echo "Running nnvm CoreML frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/coreml + +echo "Running nnvm Caffe2 frontend test..." +python3 -m nose -v nnvm/tests/python/frontend/caffe2 From 630a0aeeeb1b79f8769b4b16d52d69b96af726af Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Tue, 11 Jun 2019 16:32:12 -0700 Subject: [PATCH 117/176] [Topi] Fast mode in take op (#3325) --- include/tvm/relay/attrs/transform.h | 3 ++- python/tvm/relay/op/transform.py | 3 ++- tests/python/relay/test_op_level3.py | 6 +++++- topi/include/topi/transform.h | 26 ++++++++++++++++++++++++ topi/python/topi/transform.py | 1 + topi/tests/python/test_topi_transform.py | 9 ++++++-- 6 files changed, 43 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1b82412d0482..65febae5e7ff 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(mode).set_default("clip") .describe("Specify how out-of-bound indices will behave." "clip - clip to the range (default)" - "wrap - wrap around the indices"); + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 02fd4924b804..dce2258946cd 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"): the flattened input array is used. mode : str, optional - Specifies how out-of-bound indices will behave [clip, wrap]. + Specifies how out-of-bound indices will behave [clip, wrap, fast]. clip: clip to the range (default). wrap: wrap around the indices. + fast: no clip or wrap around (user must make sure indices are in-bound). Returns ------- diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 15cb3265b679..a878d79e678d 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -269,7 +269,8 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): func = relay.Function([x, indices], z) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) - ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode) + np_mode = "raise" if mode == "fast" else mode + ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: @@ -291,6 +292,9 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1, mode="wrap") + verify_take((3,3,3), [[11,25]], mode="fast") + verify_take((3,4), [0, 2], axis=0, mode="fast") + verify_take((3,4), [0, 2], axis=1, mode="fast") def test_split_infer_type(): diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 4dba4eade6bd..c992be6b0022 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a, auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); }, name, tag); + } else if (mode == "fast") { + LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " + "Make sure input indices are in bound"; + return compute( + out_shape, [&](const Array& out_index) { + return a(UnravelIndex(indices(out_index), a_shape)); + }, name, tag); } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { @@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a, } return a(real_indices); }, name, tag); + } else if (mode == "fast") { + LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " + "Make sure input indices are in bound"; + return compute( + out_shape, [&](const Array& out_index) { + Array indices_position; + for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + indices_position.push_back(out_index[j]); + } + Array real_indices; + for (size_t j = 0; j < static_cast(axis); ++j) { + real_indices.push_back(out_index[j]); + } + real_indices.push_back(indices(indices_position)); + for (size_t j = axis + indices_len; j < out_index.size(); ++j) { + real_indices.push_back(out_index[j]); + } + return a(real_indices); + }, name, tag); } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 04af1513576b..3d7293edc6ff 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"): Specifies how out-of-bound indices will behave. clip - clip to the range (default) wrap - wrap around the indices + fast - no clip or wrap around (user must make sure indices are in-bound) Returns ------- diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index d29fb64544b9..5682fde69372 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -275,9 +275,11 @@ def check_device(device): data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) if axis is None: - out_npys = np.take(data_npy, indices_src, mode=mode) + np_mode = "raise" if mode == "fast" else mode + out_npys = np.take(data_npy, indices_src, mode=np_mode) else: - out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode) + np_mode = "raise" if mode == "fast" else mode + out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode) data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) @@ -521,6 +523,9 @@ def test_take(): verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1, mode="wrap") + verify_take((3,3,3), [[11,25]], mode="fast") + verify_take((3,4), [0, 2], axis=0, mode="fast") + verify_take((3,4), [0, 2], axis=1, mode="fast") def test_gather_nd(): for indices_dtype in ['int32', 'float32']: From a0aeeb5362ccb613ecec3304842b6891f14afe89 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Tue, 11 Jun 2019 16:55:41 -0700 Subject: [PATCH 118/176] [VTA][TSIM] update app example (#3343) * add initial support to cycle counter to accelerator * remove prints from c * add event counter support to chisel tsim example * make it more readable * use a config class * update driver * add individual Makefile to chisel * add rule for installing vta package * add makefile for verilog backend * update drivers * update * rename * update README * put default sim back * set counter to zero --- vta/apps/tsim_example/CMakeLists.txt | 10 +- vta/apps/tsim_example/Makefile | 28 +++- vta/apps/tsim_example/README.md | 28 ++-- vta/apps/tsim_example/cmake/modules/hw.cmake | 152 ------------------ vta/apps/tsim_example/config/config.json | 7 - vta/apps/tsim_example/config/config.py | 61 ------- .../tsim_example/hardware/chisel/Makefile | 89 +++++++++- .../chisel/src/main/scala/accel/Accel.scala | 16 +- .../chisel/src/main/scala/accel/Compute.scala | 39 +++-- .../chisel/src/main/scala/accel/RegFile.scala | 72 +++++---- .../tsim_example/hardware/verilog/Makefile | 100 ++++++++++++ .../hardware/verilog/{ => src}/Accel.v | 83 ++++++---- .../hardware/verilog/{ => src}/Compute.v | 35 +++- .../hardware/verilog/{ => src}/RegFile.v | 73 ++++++--- .../hardware/verilog/{ => src}/TestAccel.v | 0 .../python/{tsim => accel}/__init__.py | 0 .../python/{tsim => accel}/driver.py | 30 ++-- vta/apps/tsim_example/src/driver.cc | 43 ++--- .../python/{add_by_one.py => chisel_accel.py} | 23 ++- .../python/verilog_accel.py} | 26 ++- .../src/main/resources/verilog/VTAHostDPI.v | 1 - vta/hardware/dpi/tsim_device.cc | 1 - 22 files changed, 511 insertions(+), 406 deletions(-) delete mode 100644 vta/apps/tsim_example/cmake/modules/hw.cmake delete mode 100644 vta/apps/tsim_example/config/config.json delete mode 100644 vta/apps/tsim_example/config/config.py create mode 100644 vta/apps/tsim_example/hardware/verilog/Makefile rename vta/apps/tsim_example/hardware/verilog/{ => src}/Accel.v (63%) rename vta/apps/tsim_example/hardware/verilog/{ => src}/Compute.v (85%) rename vta/apps/tsim_example/hardware/verilog/{ => src}/RegFile.v (72%) rename vta/apps/tsim_example/hardware/verilog/{ => src}/TestAccel.v (100%) rename vta/apps/tsim_example/python/{tsim => accel}/__init__.py (100%) rename vta/apps/tsim_example/python/{tsim => accel}/driver.py (62%) rename vta/apps/tsim_example/tests/python/{add_by_one.py => chisel_accel.py} (71%) rename vta/apps/tsim_example/{cmake/modules/sw.cmake => tests/python/verilog_accel.py} (56%) diff --git a/vta/apps/tsim_example/CMakeLists.txt b/vta/apps/tsim_example/CMakeLists.txt index 28cfded75823..56a5b9a3b228 100644 --- a/vta/apps/tsim_example/CMakeLists.txt +++ b/vta/apps/tsim_example/CMakeLists.txt @@ -34,6 +34,10 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}") endif() -# Module rules -include(cmake/modules/hw.cmake) -include(cmake/modules/sw.cmake) +file(GLOB TSIM_SW_SRC src/driver.cc) +add_library(sw SHARED ${TSIM_SW_SRC}) +target_include_directories(sw PRIVATE ${VTA_DIR}/include) + +if(APPLE) + set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") +endif(APPLE) diff --git a/vta/apps/tsim_example/Makefile b/vta/apps/tsim_example/Makefile index 2d7629ce12b2..ea8358b3dfe3 100644 --- a/vta/apps/tsim_example/Makefile +++ b/vta/apps/tsim_example/Makefile @@ -17,20 +17,32 @@ export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH) -BUILD_DIR = $(shell python3 config/config.py --get-build-name) +BUILD_NAME = build +build_dir = $(abspath .)/$(BUILD_NAME) -default: cmake run +default: verilog driver run_verilog +run_chisel: chisel driver + python3 tests/python/chisel_accel.py + .PHONY: cmake -cmake: | $(BUILD_DIR) - cd $(BUILD_DIR) && cmake .. && make +driver: | $(build_dir) + cd $(build_dir) && cmake .. && make -$(BUILD_DIR): +$(build_dir): mkdir -p $@ -run: - python3 tests/python/add_by_one.py | grep PASS +verilog: + make -C hardware/verilog + +chisel: + make -C hardware/chisel + +run_verilog: + python3 tests/python/verilog_accel.py clean: - -rm -rf $(BUILD_DIR) + -rm -rf $(build_dir) + make -C hardware/chisel clean + make -C hardware/verilog clean diff --git a/vta/apps/tsim_example/README.md b/vta/apps/tsim_example/README.md index 8f1230e9ba7e..56696fe533fc 100644 --- a/vta/apps/tsim_example/README.md +++ b/vta/apps/tsim_example/README.md @@ -49,29 +49,25 @@ sudo apt install verilator sbt ## Setup in TVM 1. Install `verilator` and `sbt` as described above -2. Change `TARGET` to `tsim` in `/tvm/vta/config/vta_config.json` -3. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library) +2. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library) ## How to run VTA TSIM examples -There are two sample VTA accelerators (add-by-one) designed in Chisel3 and Verilog to show how *TSIM* works. +There are two sample VTA accelerators, add-a-constant, designed in Chisel3 and Verilog to show how *TSIM* works. The default `TARGET` language for these two implementations is Verilog. The following instructions show how to run both of them: -* Verilog add-by-one +* Test Verilog backend * Go to `/vta/apps/tsim_example` - * Run `make` to build and run add-by-one test + * Run `make` -* Chisel3 add-by-one - * Open `/vta/apps/tsim_example/python/tsim/config.json` - * Change `TARGET` from `verilog` to `chisel` - * Go to `tvm/vta/apps/tsim_example` - * Run `make` to build and run add-by-one test +* Test Chisel3 backend + * Open `/vta/apps/tsim_example` + * Run `make run_chisel` * Some pointers - * Add-by-one test `/vta/apps/tsim_example/tests/python/add_by_one.py` - * Add-by-one accelerator in Verilog `/vta/apps/tsim_example/hardware/verilog` - * Add-by-one accelerator in Chisel3 `/vta/apps/tsim_example/hardware/chisel` - * Software driver that handles the accelerator `/vta/apps/tsim_example/src/driver.cc` - * Build cmake script for software library`/vta/apps/tsim_example/cmake/modules/sw.cmake` - * Build cmake script for hardware library`/vta/apps/tsim_example/cmake/modules/hw.cmake` + * Verilog and Chisel3 tests in `/vta/apps/tsim_example/tests/python` + * Verilog accelerator backend `/vta/apps/tsim_example/hardware/verilog` + * Chisel3 accelerator backend `/vta/apps/tsim_example/hardware/chisel` + * Software C++ driver (backend) that handles the accelerator `/vta/apps/tsim_example/src/driver.cc` + * Software Python driver (frontend) that handles the accelerator `/vta/apps/tsim_example/python/accel` diff --git a/vta/apps/tsim_example/cmake/modules/hw.cmake b/vta/apps/tsim_example/cmake/modules/hw.cmake deleted file mode 100644 index 102df9987752..000000000000 --- a/vta/apps/tsim_example/cmake/modules/hw.cmake +++ /dev/null @@ -1,152 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -if(MSVC) - message(STATUS "[TSIM_HW] build is skipped in Windows..") -else() - find_program(PYTHON NAMES python python3 python3.6) - find_program(VERILATOR NAMES verilator) - - if (VERILATOR AND PYTHON) - - if (TSIM_TOP_NAME STREQUAL "") - message(FATAL_ERROR "[TSIM_HW] TSIM_TOP_NAME should be defined") - endif() - - if (TSIM_BUILD_NAME STREQUAL "") - message(FATAL_ERROR "[TSIM_HW] TSIM_BUILD_NAME should be defined") - endif() - - set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/config/config.py) - - execute_process(COMMAND ${TSIM_CONFIG} --get-target OUTPUT_VARIABLE TSIM_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE) - execute_process(COMMAND ${TSIM_CONFIG} --get-top-name OUTPUT_VARIABLE TSIM_TOP_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) - execute_process(COMMAND ${TSIM_CONFIG} --get-build-name OUTPUT_VARIABLE TSIM_BUILD_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) - execute_process(COMMAND ${TSIM_CONFIG} --get-use-trace OUTPUT_VARIABLE TSIM_USE_TRACE OUTPUT_STRIP_TRAILING_WHITESPACE) - execute_process(COMMAND ${TSIM_CONFIG} --get-trace-name OUTPUT_VARIABLE TSIM_TRACE_NAME OUTPUT_STRIP_TRAILING_WHITESPACE) - - set(TSIM_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/${TSIM_BUILD_NAME}) - - if (TSIM_TARGET STREQUAL "chisel") - - find_program(SBT NAMES sbt) - - if (SBT) - - # Install Chisel VTA package for DPI modules - set(VTA_CHISEL_DIR ${VTA_DIR}/hardware/chisel) - - execute_process(WORKING_DIRECTORY ${VTA_CHISEL_DIR} - COMMAND ${SBT} publishLocal RESULT_VARIABLE RETCODE) - - if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM_HW] sbt failed to install VTA scala package") - endif() - - # Chisel - Scala to Verilog compilation - set(TSIM_CHISEL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/chisel) - set(CHISEL_BUILD_DIR ${TSIM_BUILD_DIR}/chisel) - set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_BUILD_DIR} --top-name ${TSIM_TOP_NAME}") - - execute_process(WORKING_DIRECTORY ${TSIM_CHISEL_DIR} COMMAND ${SBT} ${CHISEL_OPT} RESULT_VARIABLE RETCODE) - - if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM_HW] sbt failed to compile from Chisel to Verilog.") - endif() - - file(GLOB VERILATOR_RTL_SRC ${CHISEL_BUILD_DIR}/*.v) - - else() - message(FATAL_ERROR "[TSIM_HW] sbt should be installed for Chisel") - endif() # sbt - - elseif (TSIM_TARGET STREQUAL "verilog") - - set(VTA_VERILOG_DIR ${VTA_DIR}/hardware/chisel/src/main/resources/verilog) - set(TSIM_VERILOG_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/verilog) - file(GLOB VERILATOR_RTL_SRC ${VTA_VERILOG_DIR}/*.v ${TSIM_VERILOG_DIR}/*.v) - - else() - message(FATAL_ERROR "[TSIM_HW] target language can be only verilog or chisel...") - endif() # TSIM_TARGET - - if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog") - - # Check if tracing can be enabled - if (NOT TSIM_USE_TRACE STREQUAL "off") - message(STATUS "[TSIM_HW] Verilog enable tracing") - else() - message(STATUS "[TSIM_HW] Verilator disable tracing") - endif() - - # Verilator - Verilog to C++ compilation - set(VERILATOR_BUILD_DIR ${TSIM_BUILD_DIR}/verilator) - set(VERILATOR_OPT +define+RANDOMIZE_GARBAGE_ASSIGN +define+RANDOMIZE_REG_INIT) - list(APPEND VERILATOR_OPT +define+RANDOMIZE_MEM_INIT --x-assign unique) - list(APPEND VERILATOR_OPT --output-split 20000 --output-split-cfuncs 20000) - list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_BUILD_DIR}) - list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC}) - - if (NOT TSIM_USE_TRACE STREQUAL "off") - list(APPEND VERILATOR_OPT --trace) - endif() - - execute_process(COMMAND ${VERILATOR} ${VERILATOR_OPT} RESULT_VARIABLE RETCODE) - - if (NOT RETCODE STREQUAL "0") - message(FATAL_ERROR "[TSIM_HW] Verilator failed to compile Verilog to C++...") - endif() - - # Build shared library (.so) - set(VTA_HW_DPI_DIR ${VTA_DIR}/hardware/dpi) - if (EXISTS /usr/local/share/verilator/include) - set(VERILATOR_INC_DIR /usr/local/share/verilator/include) - elseif (EXISTS /usr/share/verilator/include) - set(VERILATOR_INC_DIR /usr/share/verilator/include) - else() - message(FATAL_ERROR "[TSIM_HW] Verilator include directory not found") - endif() - set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp) - - if (NOT TSIM_USE_TRACE STREQUAL "off") - list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp) - endif() - - file(GLOB VERILATOR_GEN_SRC ${VERILATOR_BUILD_DIR}/*.cpp) - file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc) - add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC}) - - set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0) - if (NOT TSIM_USE_TRACE STREQUAL "off") - list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd) - else() - list(APPEND VERILATOR_DEF VM_TRACE=0) - endif() - target_compile_definitions(hw PRIVATE ${VERILATOR_DEF}) - target_compile_options(hw PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h) - target_include_directories(hw PRIVATE ${VERILATOR_BUILD_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include) - - if(APPLE) - set_target_properties(hw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") - endif(APPLE) - - endif() # TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog" - - else() - message(STATUS "[TSIM_HW] could not find Python or Verilator, build is skipped...") - endif() # VERILATOR -endif() # MSVC diff --git a/vta/apps/tsim_example/config/config.json b/vta/apps/tsim_example/config/config.json deleted file mode 100644 index 887eaac67d74..000000000000 --- a/vta/apps/tsim_example/config/config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "TARGET" : "verilog", - "TOP_NAME" : "TestAccel", - "BUILD_NAME" : "build", - "USE_TRACE" : "off", - "TRACE_NAME" : "trace" -} diff --git a/vta/apps/tsim_example/config/config.py b/vta/apps/tsim_example/config/config.py deleted file mode 100644 index 6ff4f4234cf0..000000000000 --- a/vta/apps/tsim_example/config/config.py +++ /dev/null @@ -1,61 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os.path as osp -import sys -import json -import argparse - -cur = osp.abspath(osp.dirname(__file__)) -cfg = json.load(open(osp.join(cur, 'config.json'))) - -def main(): - """Main function""" - parser = argparse.ArgumentParser() - parser.add_argument("--get-target", action="store_true", - help="Get target language, i.e. verilog or chisel") - parser.add_argument("--get-top-name", action="store_true", - help="Get hardware design top name") - parser.add_argument("--get-build-name", action="store_true", - help="Get build folder name") - parser.add_argument("--get-use-trace", action="store_true", - help="Get use trace") - parser.add_argument("--get-trace-name", action="store_true", - help="Get trace filename") - args = parser.parse_args() - - if len(sys.argv) == 1: - parser.print_help() - return - - if args.get_target: - print(cfg['TARGET']) - - if args.get_top_name: - print(cfg['TOP_NAME']) - - if args.get_build_name: - print(cfg['BUILD_NAME']) - - if args.get_use_trace: - print(cfg['USE_TRACE']) - - if args.get_trace_name: - print(cfg['TRACE_NAME']) - -if __name__ == "__main__": - main() diff --git a/vta/apps/tsim_example/hardware/chisel/Makefile b/vta/apps/tsim_example/hardware/chisel/Makefile index 65a9ed13c989..463786a9a806 100644 --- a/vta/apps/tsim_example/hardware/chisel/Makefile +++ b/vta/apps/tsim_example/hardware/chisel/Makefile @@ -15,5 +15,92 @@ # specific language governing permissions and limitations # under the License. +ifeq (, $(shell which verilator)) + $(error "No Verilator in $(PATH), consider doing apt-get install verilator") +endif + +# Change VERILATOR_INC_DIR if Verilator is installed on a different location +ifeq (, $(VERILATOR_INC_DIR)) + ifeq (, $(wildcard /usr/local/share/verilator/include/*)) + ifeq (, $(wildcard /usr/share/verilator/include/*)) + $(error "Verilator include directory is not set properly") + else + VERILATOR_INC_DIR := /usr/share/verilator/include + endif + else + VERILATOR_INC_DIR := /usr/local/share/verilator/include + endif +endif + +TOP = TestAccel +BUILD_NAME = build +USE_TRACE = 0 +LIBNAME = libhw + +vta_dir = $(abspath ../../../../) +tvm_dir = $(abspath ../../../../../) +build_dir = $(abspath .)/$(BUILD_NAME) +verilator_build_dir = $(build_dir)/verilator +chisel_build_dir = $(build_dir)/chisel + +verilator_opt = --cc +verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN +verilator_opt += +define+RANDOMIZE_REG_INIT +verilator_opt += +define+RANDOMIZE_MEM_INIT +verilator_opt += --x-assign unique +verilator_opt += --output-split 20000 +verilator_opt += --output-split-cfuncs 20000 +verilator_opt += --top-module ${TOP} +verilator_opt += -Mdir ${verilator_build_dir} +verilator_opt += -I$(chisel_build_dir) + +cxx_flags = -O2 -Wall -fPIC -shared +cxx_flags += -fvisibility=hidden -std=c++11 +cxx_flags += -DVL_TSIM_NAME=V$(TOP) +cxx_flags += -DVL_PRINTF=printf +cxx_flags += -DVL_USER_FINISH +cxx_flags += -DVM_COVERAGE=0 +cxx_flags += -DVM_SC=0 +cxx_flags += -Wno-sign-compare +cxx_flags += -include V$(TOP).h +cxx_flags += -I$(verilator_build_dir) +cxx_flags += -I$(VERILATOR_INC_DIR) +cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd +cxx_flags += -I$(vta_dir)/include +cxx_flags += -I$(tvm_dir)/include +cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include + +cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp +cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp +cxx_files += $(wildcard $(verilator_build_dir)/*.cpp) +cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc + +ifneq ($(USE_TRACE), 0) + verilator_opt += --trace + cxx_flags += -DVM_TRACE=1 + cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd + cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp +else + cxx_flags += -DVM_TRACE=0 +endif + +default: lib + +lib: $(build_dir)/$(LIBNAME).so +$(build_dir)/$(LIBNAME).so: $(verilator_build_dir)/V$(TOP).cpp + echo $(cxx_files) + g++ $(cxx_flags) $(cxx_files) -o $@ + +verilator: $(verilator_build_dir)/V$(TOP).cpp +$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v + verilator $(verilator_opt) $< + +verilog: $(chisel_build_dir)/$(TOP).v +$(chisel_build_dir)/$(TOP).v: install_vta_package + sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)' + +install_vta_package: + cd $(vta_dir)/hardware/chisel && sbt publishLocal + clean: - -rm -rf target project/target project/project + -rm -rf $(build_dir) target project/target project/project diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala index 9225f83b0821..d654a7fdd41a 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala @@ -35,18 +35,28 @@ import vta.dpi._ * |_________| |_________| * */ +case class AccelConfig() { + val nCtrl = 1 + val nECnt = 1 + val nVals = 2 + val nPtrs = 2 + val regBits = 32 + val ptrBits = 2*regBits +} + class Accel extends Module { val io = IO(new Bundle { val host = new VTAHostDPIClient val mem = new VTAMemDPIMaster }) + implicit val config = AccelConfig() val rf = Module(new RegFile) val ce = Module(new Compute) rf.io.host <> io.host io.mem <> ce.io.mem ce.io.launch := rf.io.launch rf.io.finish := ce.io.finish - ce.io.length := rf.io.length - ce.io.inp_baddr := rf.io.inp_baddr - ce.io.out_baddr := rf.io.out_baddr + rf.io.ecnt <> ce.io.ecnt + ce.io.vals <> rf.io.vals + ce.io.ptrs <> rf.io.ptrs } diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala index fb7a2f396cb0..f24cbdd8bdb7 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala @@ -35,21 +35,24 @@ import vta.dpi._ * 6. Check if counter (cnt) is equal to length to assert finish, * otherwise go to step 2. */ -class Compute extends Module { +class Compute(implicit config: AccelConfig) extends Module { val io = IO(new Bundle { val launch = Input(Bool()) val finish = Output(Bool()) - val length = Input(UInt(32.W)) - val inp_baddr = Input(UInt(64.W)) - val out_baddr = Input(UInt(64.W)) + val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W))) + val vals = Input(Vec(config.nVals, UInt(config.regBits.W))) + val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W))) val mem = new VTAMemDPIMaster }) val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5) val state = RegInit(sIdle) + val const = io.vals(0) + val length = io.vals(1) + val cycles = RegInit(0.U(config.regBits.W)) val reg = Reg(chiselTypeOf(io.mem.rd.bits)) - val cnt = Reg(chiselTypeOf(io.length)) - val raddr = Reg(chiselTypeOf(io.inp_baddr)) - val waddr = Reg(chiselTypeOf(io.out_baddr)) + val cnt = Reg(UInt(config.regBits.W)) + val raddr = Reg(UInt(config.ptrBits.W)) + val waddr = Reg(UInt(config.ptrBits.W)) switch (state) { is (sIdle) { @@ -69,7 +72,7 @@ class Compute extends Module { state := sWriteData } is (sWriteData) { - when (cnt === (io.length - 1.U)) { + when (cnt === (length - 1.U)) { state := sIdle } .otherwise { state := sReadReq @@ -77,10 +80,22 @@ class Compute extends Module { } } + val last = state === sWriteData && cnt === (length - 1.U) + + // cycle counter + when (state === sIdle) { + cycles := 0.U + } .otherwise { + cycles := cycles + 1.U + } + + io.ecnt(0).valid := last + io.ecnt(0).bits := cycles + // calculate next address when (state === sIdle) { - raddr := io.inp_baddr - waddr := io.out_baddr + raddr := io.ptrs(0) + waddr := io.ptrs(1) } .elsewhen (state === sWriteData) { // increment by 8-bytes raddr := raddr + 8.U waddr := waddr + 8.U @@ -94,7 +109,7 @@ class Compute extends Module { // read when (state === sReadData && io.mem.rd.valid) { - reg := io.mem.rd.bits + 1.U + reg := io.mem.rd.bits + const } io.mem.rd.ready := state === sReadData @@ -110,5 +125,5 @@ class Compute extends Module { } // done when read/write are equal to length - io.finish := state === sWriteData && cnt === (io.length - 1.U) + io.finish := last } diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala index e636afdfb2e1..5fdb3529573c 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala @@ -31,11 +31,13 @@ import vta.dpi._ * Register description | addr * -------------------------|----- * Control status register | 0x00 - * Length value register | 0x04 - * Input pointer lsb | 0x08 - * Input pointer msb | 0x0c - * Output pointer lsb | 0x10 - * Output pointer msb | 0x14 + * Cycle counter | 0x04 + * Constant value | 0x08 + * Vector length | 0x0c + * Input pointer lsb | 0x10 + * Input pointer msb | 0x14 + * Output pointer lsb | 0x18 + * Output pointer msb | 0x1c * ------------------------------- * ------------------------------ @@ -45,13 +47,13 @@ import vta.dpi._ * Finish | 1 * ------------------------------ */ -class RegFile extends Module { +class RegFile(implicit config: AccelConfig) extends Module { val io = IO(new Bundle { val launch = Output(Bool()) val finish = Input(Bool()) - val length = Output(UInt(32.W)) - val inp_baddr = Output(UInt(64.W)) - val out_baddr = Output(UInt(64.W)) + val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W)))) + val vals = Output(Vec(config.nVals, UInt(config.regBits.W))) + val ptrs = Output(Vec(config.nPtrs, UInt(config.regBits.W))) val host = new VTAHostDPIClient }) val sIdle :: sRead :: Nil = Enum(2) @@ -70,23 +72,34 @@ class RegFile extends Module { io.host.req.deq := state === sIdle & io.host.req.valid - val reg = Seq.fill(6)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) - val addr = Seq.tabulate(6)(_ * 4) + val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs) + val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) + val addr = Seq.tabulate(nTotal)(_ * 4) val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } + val eo = config.nCtrl + val vo = eo + config.nECnt + val po = vo + config.nVals - (reg zip addr).foreach { case(r, a) => - if (a == 0) { // control status register - when (io.finish) { - r := "b_10".U - } .elsewhen (state === sIdle && io.host.req.valid && - io.host.req.opcode && a.U === io.host.req.addr) { - r := io.host.req.value - } - } else { - when (state === sIdle && io.host.req.valid && - io.host.req.opcode && a.U === io.host.req.addr) { - r := io.host.req.value - } + when (io.finish) { + reg(0) := "b_10".U + } .elsewhen (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(0).U === io.host.req.addr) { + reg(0) := io.host.req.value + } + + for (i <- 0 until config.nECnt) { + when (io.ecnt(i).valid) { + reg(eo + i) := io.ecnt(i).bits + } .elsewhen (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(eo + i).U === io.host.req.addr) { + reg(eo + i) := io.host.req.value + } + } + + for (i <- 0 until (config.nVals + (2*config.nPtrs))) { + when (state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(vo + i).U === io.host.req.addr) { + reg(vo + i) := io.host.req.value } } @@ -99,7 +112,12 @@ class RegFile extends Module { io.host.resp.bits := rdata io.launch := reg(0)(0) - io.length := reg(1) - io.inp_baddr := Cat(reg(3), reg(2)) - io.out_baddr := Cat(reg(5), reg(4)) + + for (i <- 0 until config.nVals) { + io.vals(i) := reg(vo + i) + } + + for (i <- 0 until config.nPtrs) { + io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) + } } diff --git a/vta/apps/tsim_example/hardware/verilog/Makefile b/vta/apps/tsim_example/hardware/verilog/Makefile new file mode 100644 index 000000000000..8a4369aa8075 --- /dev/null +++ b/vta/apps/tsim_example/hardware/verilog/Makefile @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ifeq (, $(shell which verilator)) + $(error "No Verilator in $(PATH), consider doing apt-get install verilator") +endif + +# Change VERILATOR_INC_DIR if Verilator is installed on a different location +ifeq (, $(VERILATOR_INC_DIR)) + ifeq (, $(wildcard /usr/local/share/verilator/include/*)) + ifeq (, $(wildcard /usr/share/verilator/include/*)) + $(error "Verilator include directory is not set properly") + else + VERILATOR_INC_DIR := /usr/share/verilator/include + endif + else + VERILATOR_INC_DIR := /usr/local/share/verilator/include + endif +endif + +TOP = TestAccel +BUILD_NAME = build +USE_TRACE = 0 +LIBNAME = libhw + +vta_dir = $(abspath ../../../../) +tvm_dir = $(abspath ../../../../../) +build_dir = $(abspath .)/$(BUILD_NAME) + +verilator_opt = --cc +verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN +verilator_opt += +define+RANDOMIZE_REG_INIT +verilator_opt += +define+RANDOMIZE_MEM_INIT +verilator_opt += --x-assign unique +verilator_opt += --output-split 20000 +verilator_opt += --output-split-cfuncs 20000 +verilator_opt += --top-module ${TOP} +verilator_opt += -Mdir ${build_dir} + +cxx_flags = -O2 -Wall -fPIC -shared +cxx_flags += -fvisibility=hidden -std=c++11 +cxx_flags += -DVL_TSIM_NAME=V$(TOP) +cxx_flags += -DVL_PRINTF=printf +cxx_flags += -DVL_USER_FINISH +cxx_flags += -DVM_COVERAGE=0 +cxx_flags += -DVM_SC=0 +cxx_flags += -Wno-sign-compare +cxx_flags += -include V$(TOP).h +cxx_flags += -I$(build_dir) +cxx_flags += -I$(VERILATOR_INC_DIR) +cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd +cxx_flags += -I$(vta_dir)/include +cxx_flags += -I$(tvm_dir)/include +cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include + +cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp +cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp +cxx_files += $(wildcard $(build_dir)/*.cpp) +cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc + +v_files = $(wildcard $(abspath .)/src/*.v $(vta_dir)/hardware/chisel/src/main/resources/verilog/*.v) + +ifneq ($(USE_TRACE), 0) + verilator_opt += --trace + cxx_flags += -DVM_TRACE=1 + cxx_flags += -DTSIM_TRACE_FILE=$(build_dir)/$(TOP).vcd + cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp +else + cxx_flags += -DVM_TRACE=0 +endif + +default: lib + +lib: $(build_dir)/$(LIBNAME).so +$(build_dir)/$(LIBNAME).so: $(build_dir)/V$(TOP).cpp + g++ $(cxx_flags) $(cxx_files) -o $@ + +verilator: $(build_dir)/V$(TOP).cpp +$(build_dir)/V$(TOP).cpp: $(v_files) | $(build_dir) + verilator $(verilator_opt) $(v_files) + +$(build_dir): + mkdir -p $@ + +clean: + -rm -rf $(build_dir) diff --git a/vta/apps/tsim_example/hardware/verilog/Accel.v b/vta/apps/tsim_example/hardware/verilog/src/Accel.v similarity index 63% rename from vta/apps/tsim_example/hardware/verilog/Accel.v rename to vta/apps/tsim_example/hardware/verilog/src/Accel.v index b025aad22ab7..34d7d957a858 100644 --- a/vta/apps/tsim_example/hardware/verilog/Accel.v +++ b/vta/apps/tsim_example/hardware/verilog/src/Accel.v @@ -62,6 +62,11 @@ module Accel # logic launch; logic finish; + + logic event_counter_valid; + logic [HOST_DATA_BITS-1:0] event_counter_value; + + logic [HOST_DATA_BITS-1:0] constant; logic [HOST_DATA_BITS-1:0] length; logic [MEM_ADDR_BITS-1:0] inp_baddr; logic [MEM_ADDR_BITS-1:0] out_baddr; @@ -74,22 +79,27 @@ module Accel # ) rf ( - .clock (clock), - .reset (reset), - - .host_req_valid (host_req_valid), - .host_req_opcode (host_req_opcode), - .host_req_addr (host_req_addr), - .host_req_value (host_req_value), - .host_req_deq (host_req_deq), - .host_resp_valid (host_resp_valid), - .host_resp_bits (host_resp_bits), - - .launch (launch), - .finish (finish), - .length (length), - .inp_baddr (inp_baddr), - .out_baddr (out_baddr) + .clock (clock), + .reset (reset), + + .host_req_valid (host_req_valid), + .host_req_opcode (host_req_opcode), + .host_req_addr (host_req_addr), + .host_req_value (host_req_value), + .host_req_deq (host_req_deq), + .host_resp_valid (host_resp_valid), + .host_resp_bits (host_resp_bits), + + .launch (launch), + .finish (finish), + + .event_counter_valid (event_counter_valid), + .event_counter_value (event_counter_value), + + .constant (constant), + .length (length), + .inp_baddr (inp_baddr), + .out_baddr (out_baddr) ); Compute # @@ -101,24 +111,29 @@ module Accel # ) comp ( - .clock (clock), - .reset (reset), - - .mem_req_valid (mem_req_valid), - .mem_req_opcode (mem_req_opcode), - .mem_req_len (mem_req_len), - .mem_req_addr (mem_req_addr), - .mem_wr_valid (mem_wr_valid), - .mem_wr_bits (mem_wr_bits), - .mem_rd_valid (mem_rd_valid), - .mem_rd_bits (mem_rd_bits), - .mem_rd_ready (mem_rd_ready), - - .launch (launch), - .finish (finish), - .length (length), - .inp_baddr (inp_baddr), - .out_baddr (out_baddr) + .clock (clock), + .reset (reset), + + .mem_req_valid (mem_req_valid), + .mem_req_opcode (mem_req_opcode), + .mem_req_len (mem_req_len), + .mem_req_addr (mem_req_addr), + .mem_wr_valid (mem_wr_valid), + .mem_wr_bits (mem_wr_bits), + .mem_rd_valid (mem_rd_valid), + .mem_rd_bits (mem_rd_bits), + .mem_rd_ready (mem_rd_ready), + + .launch (launch), + .finish (finish), + + .event_counter_valid (event_counter_valid), + .event_counter_value (event_counter_value), + + .constant (constant), + .length (length), + .inp_baddr (inp_baddr), + .out_baddr (out_baddr) ); endmodule diff --git a/vta/apps/tsim_example/hardware/verilog/Compute.v b/vta/apps/tsim_example/hardware/verilog/src/Compute.v similarity index 85% rename from vta/apps/tsim_example/hardware/verilog/Compute.v rename to vta/apps/tsim_example/hardware/verilog/src/Compute.v index a5660ac8bc7d..4360b1ca20dd 100644 --- a/vta/apps/tsim_example/hardware/verilog/Compute.v +++ b/vta/apps/tsim_example/hardware/verilog/src/Compute.v @@ -52,6 +52,11 @@ module Compute # input launch, output finish, + + output event_counter_valid, + output [HOST_DATA_BITS-1:0] event_counter_value, + + input [HOST_DATA_BITS-1:0] constant, input [HOST_DATA_BITS-1:0] length, input [MEM_ADDR_BITS-1:0] inp_baddr, input [MEM_ADDR_BITS-1:0] out_baddr @@ -84,7 +89,7 @@ module Compute # IDLE: begin if (launch) begin state_n = READ_REQ; - end + end end READ_REQ: begin @@ -94,9 +99,9 @@ module Compute # READ_DATA: begin if (mem_rd_valid) begin state_n = WRITE_REQ; - end else begin + end else begin state_n = READ_DATA; - end + end end WRITE_REQ: begin @@ -106,9 +111,9 @@ module Compute # WRITE_DATA: begin if (cnt == (length - 1'b1)) begin state_n = IDLE; - end else begin + end else begin state_n = READ_REQ; - end + end end default: begin @@ -116,6 +121,22 @@ module Compute # endcase end + logic last; + assign last = (state_r == WRITE_DATA) & (cnt == (length - 1'b1)); + + // cycle counter + logic [HOST_DATA_BITS-1:0] cycle_counter; + always_ff @(posedge clock) begin + if (reset | state_r == IDLE) begin + cycle_counter <= '0; + end else begin + cycle_counter <= cycle_counter + 1'b1; + end + end + + assign event_counter_valid = last; + assign event_counter_value = cycle_counter; + // calculate next address always_ff @(posedge clock) begin if (reset | state_r == IDLE) begin @@ -136,7 +157,7 @@ module Compute # // read always_ff @(posedge clock) begin if ((state_r == READ_DATA) & mem_rd_valid) begin - data <= mem_rd_bits + 1'b1; + data <= mem_rd_bits + {32'd0, constant}; end end assign mem_rd_ready = state_r == READ_DATA; @@ -155,5 +176,5 @@ module Compute # end // done when read/write are equal to length - assign finish = (state_r == WRITE_DATA) & (cnt == (length - 1'b1)); + assign finish = last; endmodule diff --git a/vta/apps/tsim_example/hardware/verilog/RegFile.v b/vta/apps/tsim_example/hardware/verilog/src/RegFile.v similarity index 72% rename from vta/apps/tsim_example/hardware/verilog/RegFile.v rename to vta/apps/tsim_example/hardware/verilog/src/RegFile.v index 28edf9672f48..7174682dc8a2 100644 --- a/vta/apps/tsim_example/hardware/verilog/RegFile.v +++ b/vta/apps/tsim_example/hardware/verilog/src/RegFile.v @@ -25,11 +25,13 @@ * Register description | addr * -------------------------|----- * Control status register | 0x00 - * Length value register | 0x04 - * Input pointer lsb | 0x08 - * Input pointer msb | 0x0c - * Output pointer lsb | 0x10 - * Output pointer msb | 0x14 + * Cycle counter | 0x04 + * Constant value | 0x08 + * Vector length | 0x0c + * Input pointer lsb | 0x10 + * Input pointer msb | 0x14 + * Output pointer lsb | 0x18 + * Output pointer msb | 0x1c * ------------------------------- * ------------------------------ @@ -58,11 +60,18 @@ module RegFile # output launch, input finish, + + input event_counter_valid, + input [HOST_DATA_BITS-1:0] event_counter_value, + + output [HOST_DATA_BITS-1:0] constant, output [HOST_DATA_BITS-1:0] length, output [MEM_ADDR_BITS-1:0] inp_baddr, output [MEM_ADDR_BITS-1:0] out_baddr ); + localparam NUM_REG = 8; + typedef enum logic {IDLE, READ} state_t; state_t state_n, state_r; @@ -80,7 +89,7 @@ module RegFile # IDLE: begin if (host_req_valid & ~host_req_opcode) begin state_n = READ; - end + end end READ: begin @@ -91,28 +100,49 @@ module RegFile # assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0; - logic [HOST_DATA_BITS-1:0] rf [5:0]; + logic [HOST_DATA_BITS-1:0] rf [NUM_REG-1:0]; genvar i; - for (i = 0; i < 6; i++) begin + for (i = 0; i < NUM_REG; i++) begin + logic wen = (state_r == IDLE)? host_req_valid & host_req_opcode & i*4 == host_req_addr : 1'b0; + if (i == 0) begin + always_ff @(posedge clock) begin if (reset) begin - end else if (finish) begin - rf[i] <= 'd2; - end else if (wen) begin - rf[i] <= host_req_value; - end + rf[i] <= 'd0; + end else if (finish) begin + rf[i] <= 'd2; + end else if (wen) begin + rf[i] <= host_req_value; + end end + + end else if (i == 1) begin + + always_ff @(posedge clock) begin + if (reset) begin + rf[i] <= 'd0; + end else if (event_counter_valid) begin + rf[i] <= event_counter_value; + end else if (wen) begin + rf[i] <= host_req_value; + end + end + end else begin + always_ff @(posedge clock) begin if (reset) begin - end else if (wen) begin - rf[i] <= host_req_value; - end + rf[i] <= 'd0; + end else if (wen) begin + rf[i] <= host_req_value; + end end + end + end logic [HOST_DATA_BITS-1:0] rdata; @@ -132,6 +162,10 @@ module RegFile # rdata <= rf[4]; end else if (host_req_addr == 'h14) begin rdata <= rf[5]; + end else if (host_req_addr == 'h18) begin + rdata <= rf[6]; + end else if (host_req_addr == 'h1c) begin + rdata <= rf[7]; end else begin rdata <= 'd0; end @@ -142,8 +176,9 @@ module RegFile # assign host_resp_bits = rdata; assign launch = rf[0][0]; - assign length = rf[1]; - assign inp_baddr = {rf[3], rf[2]}; - assign out_baddr = {rf[5], rf[4]}; + assign constant = rf[2]; + assign length = rf[3]; + assign inp_baddr = {rf[5], rf[4]}; + assign out_baddr = {rf[7], rf[6]}; endmodule diff --git a/vta/apps/tsim_example/hardware/verilog/TestAccel.v b/vta/apps/tsim_example/hardware/verilog/src/TestAccel.v similarity index 100% rename from vta/apps/tsim_example/hardware/verilog/TestAccel.v rename to vta/apps/tsim_example/hardware/verilog/src/TestAccel.v diff --git a/vta/apps/tsim_example/python/tsim/__init__.py b/vta/apps/tsim_example/python/accel/__init__.py similarity index 100% rename from vta/apps/tsim_example/python/tsim/__init__.py rename to vta/apps/tsim_example/python/accel/__init__.py diff --git a/vta/apps/tsim_example/python/tsim/driver.py b/vta/apps/tsim_example/python/accel/driver.py similarity index 62% rename from vta/apps/tsim_example/python/tsim/driver.py rename to vta/apps/tsim_example/python/accel/driver.py index c388b99cbec9..6d8e7181b707 100644 --- a/vta/apps/tsim_example/python/tsim/driver.py +++ b/vta/apps/tsim_example/python/accel/driver.py @@ -17,31 +17,25 @@ import tvm import ctypes -import json import os.path as osp from sys import platform -def driver(hw_lib, sw_lib): - """Init hardware and software shared library for add-by-one accelerator +def driver(hw_backend): + """Init hardware and software shared library for accelerator Parameters ------------ - hw_lib : str - Name of hardware shared library + hw_backend : str + Hardware backend can be verilog or chisel - sw_lib : str - Name of software shared library """ + _ext = ".dylib" if platform == "darwin" else ".so" + _hw_libname = "libhw" + _ext + _sw_libname = "libsw" + _ext _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) - _root_path = osp.join(_cur_path, "..", "..") - _cfg_file = osp.join(_root_path, "config", "config.json") - _cfg = json.load(open(_cfg_file)) - if not hw_lib.endswith(("dylib", "so")): - hw_lib += ".dylib" if platform == "darwin" else ".so" - if not sw_lib.endswith(("dylib", "so")): - sw_lib += ".dylib" if platform == "darwin" else ".so" - _hw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], hw_lib) - _sw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], sw_lib) + if hw_backend in ("verilog", "chisel"): + _hw_lib = osp.join(_cur_path, "..", "..", "hardware", hw_backend, "build", _hw_libname) + _sw_lib = osp.join(_cur_path, "..", "..", "build", _sw_libname) def load_dll(dll): try: @@ -49,9 +43,9 @@ def load_dll(dll): except OSError: return [] - def run(a, b): + def run(a, b, c): load_dll(_sw_lib) f = tvm.get_global_func("tvm.vta.driver") m = tvm.module.load(_hw_lib, "vta-tsim") - f(m, a, b) + return f(m, a, b, c) return run diff --git a/vta/apps/tsim_example/src/driver.cc b/vta/apps/tsim_example/src/driver.cc index c11a8f8a3ee7..ad9d6ddf2620 100644 --- a/vta/apps/tsim_example/src/driver.cc +++ b/vta/apps/tsim_example/src/driver.cc @@ -43,34 +43,40 @@ class Device { module.operator->()); } - int Run(uint32_t length, void* inp, void* out) { - uint32_t wait_cycles = 100000000; - this->Launch(wait_cycles, length, inp, out); - this->WaitForCompletion(wait_cycles); + uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) { + uint32_t cycles; + this->Launch(c, length, inp, out); + cycles = this->WaitForCompletion(); dpi_->Finish(); - return 0; + return cycles; } private: - void Launch(uint32_t wait_cycles, uint32_t length, void* inp, void* out) { - dpi_->Launch(wait_cycles); - // write registers - dpi_->WriteReg(0x04, length); - dpi_->WriteReg(0x08, get_half_addr(inp, false)); - dpi_->WriteReg(0x0c, get_half_addr(inp, true)); - dpi_->WriteReg(0x10, get_half_addr(out, false)); - dpi_->WriteReg(0x14, get_half_addr(out, true)); - dpi_->WriteReg(0x00, 0x1); // launch + void Launch(uint32_t c, uint32_t length, void* inp, void* out) { + dpi_->Launch(wait_cycles_); + // set counter to zero + dpi_->WriteReg(0x04, 0); + dpi_->WriteReg(0x08, c); + dpi_->WriteReg(0x0c, length); + dpi_->WriteReg(0x10, get_half_addr(inp, false)); + dpi_->WriteReg(0x14, get_half_addr(inp, true)); + dpi_->WriteReg(0x18, get_half_addr(out, false)); + dpi_->WriteReg(0x1c, get_half_addr(out, true)); + // launch + dpi_->WriteReg(0x00, 0x1); } - void WaitForCompletion(uint32_t wait_cycles) { + uint32_t WaitForCompletion() { uint32_t i, val; - for (i = 0; i < wait_cycles; i++) { + for (i = 0; i < wait_cycles_; i++) { val = dpi_->ReadReg(0x00); - if (val == 2) break; // finish + if (val == 2) break; // finish } + val = dpi_->ReadReg(0x04); + return val; } + uint32_t wait_cycles_{100000000}; DPIModuleNode* dpi_; Module module_; }; @@ -84,7 +90,8 @@ TVM_REGISTER_GLOBAL("tvm.vta.driver") DLTensor* A = args[1]; DLTensor* B = args[2]; Device dev_(dev_mod); - dev_.Run(A->shape[0], A->data, B->data); + uint32_t cycles = dev_.Run(static_cast(args[3]), A->shape[0], A->data, B->data); + *rv = static_cast(cycles); }); } // namespace driver diff --git a/vta/apps/tsim_example/tests/python/add_by_one.py b/vta/apps/tsim_example/tests/python/chisel_accel.py similarity index 71% rename from vta/apps/tsim_example/tests/python/add_by_one.py rename to vta/apps/tsim_example/tests/python/chisel_accel.py index 6e0d094367b5..6ab0bf5a36eb 100644 --- a/vta/apps/tsim_example/tests/python/add_by_one.py +++ b/vta/apps/tsim_example/tests/python/chisel_accel.py @@ -18,22 +18,21 @@ import tvm import numpy as np -from tsim.driver import driver +from accel.driver import driver -def test_tsim(i): - rmin = 1 # min vector size of 1 +def test_accel(): rmax = 64 - n = np.random.randint(rmin, rmax) + n = np.random.randint(1, rmax) + c = np.random.randint(0, rmax) ctx = tvm.cpu(0) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) - f = driver("libhw", "libsw") - f(a, b) - emsg = "[FAIL] test number:{} n:{}".format(i, n) - np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1, err_msg=emsg) - print("[PASS] test number:{} n:{}".format(i, n)) + f = driver("chisel") + cycles = f(a, b, c) + msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) + print("[PASS] " + msg) if __name__ == "__main__": - times = 10 - for i in range(times): - test_tsim(i) + for i in range(10): + test_accel() diff --git a/vta/apps/tsim_example/cmake/modules/sw.cmake b/vta/apps/tsim_example/tests/python/verilog_accel.py similarity index 56% rename from vta/apps/tsim_example/cmake/modules/sw.cmake rename to vta/apps/tsim_example/tests/python/verilog_accel.py index d0368c3edc75..97f636cbfde1 100644 --- a/vta/apps/tsim_example/cmake/modules/sw.cmake +++ b/vta/apps/tsim_example/tests/python/verilog_accel.py @@ -15,10 +15,24 @@ # specific language governing permissions and limitations # under the License. -file(GLOB TSIM_SW_SRC src/driver.cc) -add_library(sw SHARED ${TSIM_SW_SRC}) -target_include_directories(sw PRIVATE ${VTA_DIR}/include) +import tvm +import numpy as np -if(APPLE) - set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") -endif(APPLE) +from accel.driver import driver + +def test_accel(): + rmax = 64 + n = np.random.randint(1, rmax) + c = np.random.randint(0, rmax) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) + b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) + f = driver("verilog") + cycles = f(a, b, c) + msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) + print("[PASS] " + msg) + +if __name__ == "__main__": + for i in range(10): + test_accel() diff --git a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v index 8ab85f6b752c..b466c79d4555 100644 --- a/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v +++ b/vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v @@ -112,7 +112,6 @@ module VTAHostDPI # always_ff @(posedge clock) begin if (__exit == 'd1) begin - $display("[TSIM] Verilog $finish called at cycle:%016d", cycles); $finish; end end diff --git a/vta/hardware/dpi/tsim_device.cc b/vta/hardware/dpi/tsim_device.cc index 0b315e4cb541..aa05c8c2663c 100644 --- a/vta/hardware/dpi/tsim_device.cc +++ b/vta/hardware/dpi/tsim_device.cc @@ -75,7 +75,6 @@ void VTADPIInit(VTAContextHandle handle, // VL_USER_FINISH needs to be defined when compiling Verilator code void vl_finish(const char* filename, int linenum, const char* hier) { Verilated::gotFinish(true); - VL_PRINTF("[TSIM] exiting simulation\n"); } int VTADPISim(uint64_t max_cycles) { From 590c457e742e6ef43f9cf4ce36e292a22120a640 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Tue, 11 Jun 2019 21:10:51 -0700 Subject: [PATCH 119/176] Non_maximum_suppression and get_valid_counts add new parameters (#3335) --- topi/python/topi/cuda/nms.py | 53 ++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index c0da4a45ec8d..417e2dce5774 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -27,7 +27,7 @@ from .. import tag -def get_valid_counts_pre(data, flag, idx, score_threshold): +def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index): """Low level IR to Prepare get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -46,6 +46,12 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): score_threshold : float32 Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- stmt : Stmt @@ -61,6 +67,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): flag = ib.buffer_ptr(flag) idx = ib.buffer_ptr(idx) score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold) + id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) + score_index = tvm.make.node("IntImm", dtype="int32", value=score_index) max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) nthread_tx = max_threads @@ -72,7 +80,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): tid = bx * max_threads + tx with ib.if_scope(tid < batch_size * num_anchors): - with ib.if_scope(data[tid * box_data_length + 1] > score_threshold): + with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \ + tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))): flag[tid] = 1 idx[tid] = 1 with ib.else_scope(): @@ -356,7 +365,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): temp_flag, temp_idx = \ tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data], lambda ins, outs: get_valid_counts_pre( - ins[0], outs[0], outs[1], score_threshold), + ins[0], outs[0], outs[1], score_threshold, id_index, score_index), dtype=["int32", "int32"], out_buffers=[temp_flag_buf, temp_idx_buf], name="get_valid_counts_phase_one") @@ -395,7 +404,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): def nms_ir(data, sorted_index, valid_count, out, box_indices, max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index): + top_k, coord_start, id_index, score_index): """Low level IR routing for transform location in multibox_detection operator. Parameters @@ -431,6 +440,9 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, id_index : int index of the class categories, -1 to disable. + score_index : optional, int + Index of the scores/confidence of boxes. + Returns ------- stmt : Stmt @@ -477,6 +489,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start) id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) + score_index = tvm.make.node("IntImm", dtype="int32", value=score_index) force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) with ib.for_range(0, batch_size, for_type="unroll") as i: @@ -498,20 +511,26 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 box_indices[i * num_anchors + (j + nkeep)] = -1 # Apply nms - with ib.if_scope(j < valid_count[i]): - offset_j = j * box_data_length - with ib.if_scope(out[base_idx + offset_j] >= 0): - with ib.for_range(0, valid_count[i]) as k: - offset_k = k * box_data_length - with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \ + with ib.for_range(0, valid_count[i]) as k: + offset_k = k * box_data_length + with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \ + tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))): + with ib.if_scope(j < valid_count[i]): + offset_j = j * box_data_length + with ib.if_scope(tvm.all(j > k, \ + out[base_idx + offset_j + score_index] > 0, \ + tvm.any(id_index < 0, \ + out[base_idx + offset_j + id_index] >= 0), \ tvm.any(force_suppress > 0, id_index < 0, \ - out[base_idx + offset_j] == \ - out[base_idx + offset_k]))): - iou = calculate_overlap(out, base_idx + offset_k + coord_start, - base_idx + offset_j + coord_start) + out[base_idx + offset_k + id_index] == \ + out[base_idx + offset_j + id_index]))): + iou = calculate_overlap(out, base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start) with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_k] = -1.0 - box_indices[i * num_anchors + k] = -1 + out[base_idx + offset_j + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_j + id_index] = -1.0 + box_indices[i * num_anchors + j] = -1 with ib.else_scope(): with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length @@ -749,7 +768,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, lambda ins, outs: nms_ir( ins[0], ins[1], ins[2], outs[0], outs[1], max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index), + top_k, coord_start, id_index, score_index), dtype=[data.dtype, "int32"], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], name="nms", From 7e8de9bda63c94d7bf9379d3a056320490cb8199 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Wed, 12 Jun 2019 05:22:22 +0100 Subject: [PATCH 120/176] [DOC] minor grammatical improvements (#3341) --- docs/dev/codebase_walkthrough.rst | 6 ++-- python/tvm/tensor.py | 8 ++--- tutorials/autotvm/tune_simple_template.py | 2 +- tutorials/cross_compilation_and_rpc.py | 44 +++++++++++------------ 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 6aa175c3f114..e6df9becdc9e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -33,9 +33,9 @@ At the root of the TVM repository, we have following subdirectories that togethe - ``topi`` - Compute definitions and backend schedules for standard neural network operators. - ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. After the introduction of Relay, it remains in the codebase for backward compatibility. -Using standard Deep Learning terminologies, ``src/relay`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in the rest of ``src``. ``python`` provides python bindings for the C++ API and driver code that users can use to execute compilation. Operators corresponding to each node are registered in ``src/relay/op``. Implementations for operators are in ``topi``, and they are coded in either C++ or Python. +Using standard Deep Learning terminology, ``src/relay`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructure implemented in the rest of ``src``. ``python`` provides python bindings for the C++ API and driver code that users can use to execute compilation. Operators corresponding to each node are registered in ``src/relay/op``. Implementations of operators are in ``topi``, and they are coded in either C++ or Python. -Relay is the new IR for deep networks that is intended to replace NNVM. If you have used NNVM, Relay provides equivalent or better functionalities. In fact, Relay goes beyond a traditional way of thinking deep networks in terms of computational graphs. But for the purpose of this document, we can think of Relay as a traditional computational graph framework. You can read more about Relay `here `_. +Relay is the new IR for deep networks that is intended to replace NNVM. If you have used NNVM, Relay provides equivalent or better functionality. In fact, Relay goes beyond a traditional way of thinking deep networks in terms of computational graphs. But for the purpose of this document, we can think of Relay as a traditional computational graph framework. You can read more about Relay `here `_. When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compiler.build(...)`` for the older API), the following sequence of actions happens for each node in the graph: @@ -43,7 +43,7 @@ When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compile - Generate a compute expression and a schedule for the operator - Compile the operator into object code -One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that does heavy lifting is implemented in C++, and Python bindings are provided for the user interface. This is also true in TVM, but in TVM codebase, C++ code can also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. +One of the interesting aspects of the TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that performs heavy lifting is implemented in C++, and Python bindings are provided for the user interface. This is also true in TVM, but in the TVM codebase, C++ code can also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. ******************************************* Vector Add Example diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index ce7cbae385d9..db8fb272a551 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -147,7 +147,7 @@ def output(self, index): @property def num_outputs(self): - """Number of outputs of this op.""" + """Number of outputs from this op.""" return _api_internal._OpNumOutputs(self) @property @@ -166,7 +166,7 @@ class BaseComputeOp(Operation): """Compute operation.""" @property def axis(self): - """Represent axis of IterVar, defined when it is a ComputeOp""" + """Represent the IterVar axis, defined when it is a ComputeOp""" return self.__getattr__("axis") @property @@ -191,7 +191,7 @@ class ScanOp(Operation): """Scan operation.""" @property def scan_axis(self): - """Represent axis of scan, only defined when it is a ScanOp""" + """Represent the scan axis, only defined when it is a ScanOp""" return self.__getattr__("scan_axis") @@ -205,7 +205,7 @@ class HybridOp(Operation): """Hybrid operation.""" @property def axis(self): - """Represent axis of IterVar, also defined when it is a HybridOp""" + """Represent the IterVar axis, also defined when it is a HybridOp""" return self.__getattr__("axis") diff --git a/tutorials/autotvm/tune_simple_template.py b/tutorials/autotvm/tune_simple_template.py index 0a7b9f2dd816..dc1b2ce4a4fd 100644 --- a/tutorials/autotvm/tune_simple_template.py +++ b/tutorials/autotvm/tune_simple_template.py @@ -133,7 +133,7 @@ def matmul_v1(N, L, M, dtype): # Here we make four modifications to the previous schedule code and get # a tunable "template". We can explain the modifications one by one. # -# 1. Use a decorator to mark this function as a simple template +# 1. Use a decorator to mark this function as a simple template. # 2. Get a config object: # You can regard this :code:`cfg` as an argument of this function but # we obtain it in a different way. With this argument, this function is no longer diff --git a/tutorials/cross_compilation_and_rpc.py b/tutorials/cross_compilation_and_rpc.py index ea1b88cbf96a..a75398402f9f 100644 --- a/tutorials/cross_compilation_and_rpc.py +++ b/tutorials/cross_compilation_and_rpc.py @@ -24,11 +24,11 @@ This tutorial introduces cross compilation and remote device execution with RPC in TVM. -With cross compilation and RPC, you can **compile program on your +With cross compilation and RPC, you can **compile a program on your local machine then run it on the remote device**. It is useful when -the resource of remote devices is limited, like Raspberry Pi and mobile -platforms. In this tutorial, we will take Raspberry Pi for CPU example -and Firefly-RK3399 for opencl example. +the remote device resource are limited, like Raspberry Pi and mobile +platforms. In this tutorial, we will use the Raspberry Pi for a CPU example +and the Firefly-RK3399 for an OpenCL example. """ ###################################################################### @@ -39,9 +39,9 @@ # # .. note:: # -# All instructions in both this section and next section should be -# executed on the target device, e.g. Raspberry Pi. And we assume it -# has Linux running. +# All instructions in both this section and the next section should be +# executed on the target device, e.g. Raspberry Pi. We assume the target +# is running Linux. # # Since we do compilation on the local machine, the remote device is only used # for running the generated code. We only need to build the TVM runtime on @@ -53,7 +53,7 @@ # cd tvm # make runtime -j2 # -# After building runtime successfully, we need to set environment variables +# After building the runtime successfully, we need to set environment variables # in :code:`~/.bashrc` file. We can edit :code:`~/.bashrc` # using :code:`vi ~/.bashrc` and add the line below (Assuming your TVM # directory is in :code:`~/tvm`): @@ -88,7 +88,7 @@ # # .. note:: # -# Now we back to the local machine, which has a full TVM installed +# Now we go back to the local machine, which has a full TVM installed # (with LLVM). # # Here we will declare a simple kernel on the local machine: @@ -127,15 +127,15 @@ # .. note:: # # To run this tutorial with a real remote device, change :code:`local_demo` -# to False and replace :code:`target` in :code:`build` with the true -# target triple of your device. The target triple which might be +# to False and replace :code:`target` in :code:`build` with the appropriate +# target triple for your device. The target triple which might be # different for different devices. For example, it is # :code:`'llvm -target=armv7l-linux-gnueabihf'` for Raspberry Pi 3B and # :code:`'llvm -target=aarch64-linux-gnu'` for RK3399. # -# Usually, you can query the target by execute :code:`gcc -v` on your -# device, and look for the line starting with :code:`Target:` -# (Though it may be still a loose configuration.) +# Usually, you can query the target by running :code:`gcc -v` on your +# device, and looking for the line starting with :code:`Target:` +# (Though it may still be a loose configuration.) # # Besides :code:`-target`, you can also set other compilation options # like: @@ -160,7 +160,7 @@ ###################################################################### # Run CPU Kernel Remotely by RPC # ------------------------------ -# We show how to run the generated cpu kernel on the remote device. +# We show how to run the generated CPU kernel on the remote device. # First we obtain an RPC session from remote device. if local_demo: @@ -200,8 +200,8 @@ ######################################################################### # Run OpenCL Kernel Remotely by RPC # --------------------------------- -# As for remote OpenCL devices, the workflow is almost the same as above. -# You can define the kernel, upload files, and run by RPC. +# For remote OpenCL devices, the workflow is almost the same as above. +# You can define the kernel, upload files, and run via RPC. # # .. note:: # @@ -209,7 +209,7 @@ # Firefly-RK3399. You may follow this `tutorial `_ # to setup the OS and OpenCL driver for RK3399. # -# Also we need to build the runtime with OpenCL enabled on rk3399 board. In the tvm +# Also we need to build the runtime with OpenCL enabled on rk3399 board. In the TVM # root directory, execute # # .. code-block:: bash @@ -218,7 +218,7 @@ # sed -i "s/USE_OPENCL OFF/USE_OPENCL ON/" config.cmake # make runtime -j4 # -# The following function shows how we run OpenCL kernel remotely +# The following function shows how we run an OpenCL kernel remotely def run_opencl(): # NOTE: This is the setting for my rk3399 board. You need to modify @@ -256,7 +256,7 @@ def run_opencl(): # This tutorial provides a walk through of cross compilation and RPC # features in TVM. # -# - Set up RPC server on the remote device. -# - Set up target device configuration to cross compile kernel on the +# - Set up an RPC server on the remote device. +# - Set up the target device configuration to cross compile the kernels on the # local machine. -# - Upload and run the kernel remotely by RPC API. +# - Upload and run the kernels remotely via the RPC API. From 515cbb2af9ff6db5a04e2159bdab5563d4ba6a18 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Wed, 12 Jun 2019 05:22:48 +0100 Subject: [PATCH 121/176] [DOC] clarfiy explanation (#3340) --- include/tvm/schedule.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 774d7cd9a40a..659b42aa1afa 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -102,8 +102,8 @@ class Stage : public NodeRef { */ EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar); /*! - * \brief Set predicate under which store to the array can be performed. - * Use this when there are duplicated threads doing the same store and we only + * \brief Set the predicate to determine whether a store to the array should be performed. + * Use this when there are multiple threads performing the same store and we only * need one of them to do the store. * * \note This is a dangerous scheduling primitive that can change behavior of program. From b131be2befaadf4c19b07a6b2a6c28c032530d20 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 12 Jun 2019 09:58:15 -0700 Subject: [PATCH 122/176] [Relay][Backend] Fix interpreter argument conversion for tuples. (#3349) * Support taking a tuple as an argument * Add test --- python/tvm/relay/backend/interpreter.py | 2 ++ .../python/relay/test_backend_interpreter.py | 28 ++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index fc47f4e1b7c8..593cf7cfbdf7 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -118,6 +118,8 @@ def _arg_to_ast(arg): return Constant(arg.data.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): return Tuple([_arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, tuple): + return Tuple([_arg_to_ast(field) for field in arg]) elif isinstance(arg, RefValue): return RefCreate(_arg_to_ast(arg.value)) elif isinstance(arg, ConstructorValue): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index e8a99e14d741..1e5e2310e927 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -217,6 +217,31 @@ def test_function_taking_adt_ref_tuple(): tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), tuple_value.fields[i].asnumpy()) +def test_tuple_passing(): + x = relay.var('x', type_annotation=relay.ty.TupleType([ + relay.ty.TensorType((), 'int64'), + relay.ty.TensorType((), 'int64')])) + + fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) + mod = relay.Module({}) + gv = relay.GlobalVar('fn') + mod[gv] = fn + mod.entry_func = gv + mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod) + + ctx = tvm.cpu() + target = tvm.target.create('llvm') + exec = relay.create_executor(mod=mod, ctx=ctx, target=target) + f = exec.evaluate(gv) + # First use a Python tuple. + out = f((10, 8)) + tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) + # Second use a tuple value. + value_tuple = TupleValue( + TensorValue(np.array(11)), + TensorValue(np.array(12))) + out = f(value_tuple) + tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) if __name__ == "__main__": test_id() @@ -231,4 +256,5 @@ def test_function_taking_adt_ref_tuple(): test_tensor_value() test_tuple_value() test_tuple_getitem() - test_function_taking_adt_ref_tuple() \ No newline at end of file + test_function_taking_adt_ref_tuple() + test_tuple_passing() From 7f4510c2cbfdc59036b49a22d1bcebdfc9f47957 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 12 Jun 2019 14:23:15 -0700 Subject: [PATCH 123/176] [Relay][Frontend] Fix MxNet RNN without providing state initialization as input (#3326) --- python/tvm/relay/frontend/mxnet.py | 48 ++++++++++++++++++--- tests/python/frontend/mxnet/test_forward.py | 48 +++++++++++++-------- 2 files changed, 72 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 81ef51b91336..ff5f81dc7069 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -93,6 +93,15 @@ def impl(inputs, attrs): return impl +def _mx_zeros(inputs, attrs): + assert len(inputs) == 0 + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + if 0 in shape: + return None + return _op.zeros(shape=shape, dtype=dtype) + + def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: @@ -754,9 +763,30 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_data = inputs[0] concat_weight = inputs[1] - concat_states = inputs[2:] - seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0]) + init_states = inputs[2:] + + data_shape = ir_pass.infer_type(seq_data).checked_type.shape + seq_len = int(data_shape[0]) assert len(concat_weight) == num_layers * 4 + output_states = True + for idx, state in enumerate(init_states[:]): + if isinstance(state, dict): + node = state + attrs = StrAttrsDict(node.get("attrs", {})) + op_name = node["op"] + # by default, RNN layer uses zeros to initialize states + assert op_name == "_zeros" + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + init_layout = attrs.get_str("__layout__") + new_shape = list(shape) + for i, dim in enumerate(shape): + if dim == 0: + axis = layout.find(init_layout[i]) + assert axis >= 0 + new_shape[i] = int(data_shape[axis]) + init_states[idx] = _op.zeros(new_shape, dtype) + output_states = False weights = [] bias = [] @@ -768,7 +798,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): for j in range(2): w.append(concat_weight[i*2 + j].args[0]) b.append(concat_weight[num_layers*2 + i*2 + j].args[0]) - for state in concat_states: + for state in init_states: s.append(_op.take(state, _expr.const(i, "int32"), axis=0)) weights.append(w) bias.append(b) @@ -789,8 +819,9 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_output.append(out) outputs = [_op.stack(seq_output, axis=0)] - for i in range(num_states): - outputs.append(_op.stack([s[i] for s in states], axis=0)) + if output_states: + for i in range(num_states): + outputs.append(_op.stack([s[i] for s in states], axis=0)) return outputs @@ -881,7 +912,6 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "argmin" : _arg_reduce(_op.argmin), # init ops "_ones" : _init_op(_op.ones), - "_zeros" : _init_op(_op.zeros), # softmax "softmax" : _softmax_op(_op.nn.softmax), "log_softmax" : _softmax_op(_op.nn.log_softmax), @@ -895,6 +925,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations + "_zeros" : _mx_zeros, "FullyConnected": _mx_fully_connected, "Activation" : _mx_activations, "Convolution" : _mx_conv2d, @@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)] elif op_name in _convert_map: res = _convert_map[op_name](children, attrs) - if isinstance(res, (_expr.TupleWrapper, tuple, list)): + if res is None: + # defer conversion, used in RNN state initialization + res = [node] + elif isinstance(res, (_expr.TupleWrapper, tuple, list)): pass elif isinstance(res, _expr.Expr): res = [res] diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 7569257830af..8d7c15bb0be5 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -536,7 +536,7 @@ def test_forward_bilinear_resize(): verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10)) def test_forward_rnn_layer(): - def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1): + def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True): if mode == "rnn": layer = gluon.rnn.RNN(hidden_size, num_layers) elif mode == "gru": @@ -545,23 +545,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1): layer = gluon.rnn.LSTM(hidden_size, num_layers) num_states = 2 if mode == "lstm" else 1 layer.initialize() + layer.hybridize() dtype = "float32" + batch = 1 data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype) - states_np = [] - states_mx = [] - shape_dict = {'data0': data_np.shape} - inputs = {'data0': data_np} - for i in range(num_states): - s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype) - states_np.append(s) - states_mx.append(mx.nd.array(s)) - shape_dict['data%s' % (i+1)] = s.shape - inputs['data%s' % (i+1)] = s + data_mx = mx.nd.array(data_np) + + if init_states: + shape_dict = {'data0': data_np.shape} + inputs = {'data0': data_np} + states_np = [] + states_mx = [] + for i in range(num_states): + s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype) + states_np.append(s) + states_mx.append(mx.nd.array(s)) + shape_dict['data%s' % (i+1)] = s.shape + inputs['data%s' % (i+1)] = s + mx_out, mx_states = layer(data_mx, states_mx) + mx_res = [mx_out] + mx_states + else: + shape_dict = {'data': data_np.shape} + inputs = {'data': data_np} + mx_res = layer(data_mx) - layer.hybridize() - mx_out, mx_states = layer(mx.nd.array(data_np), states_mx) - mx_res = [mx_out] + mx_states mx_sym = layer._cached_graph[1] mx_params = {} for name, param in layer.collect_params().items(): @@ -574,14 +582,20 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1): for kind in ["graph"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(new_sym)(**inputs, **params) - assert len(op_res) == len(mx_res) - for i, val in enumerate(op_res): - tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3) + if init_states: + assert len(op_res) == len(mx_res) + for i, val in enumerate(op_res): + tvm.testing.assert_allclose( + val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3) + else: + tvm.testing.assert_allclose( + op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3) for mode in ["rnn", "gru", "lstm"]: verify(mode, 64, 10, 64, 1) verify(mode, 64, 10, 64, 2) verify(mode, 64, 10, 32, 2) + verify(mode, 64, 10, 64, 2, init_states=False) def test_forward_Crop(): def verify(xshape, yshape, offset=None): From 4fcf58260b97d0eb061f33f83f8056926300c339 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 12 Jun 2019 14:56:18 -0700 Subject: [PATCH 124/176] [Relay] add ClipByValue and Neg in tf frontend converter (#3211) --- python/tvm/relay/frontend/tensorflow.py | 9 ++++++ .../frontend/tensorflow/test_forward.py | 29 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4f241952db2e..ba076cc2819f 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -941,6 +941,13 @@ def _impl(inputs, attr, params): return AttrCvt(op_name="where")(inputs, attr) return _impl +def _clip_by_value(): + def _impl(inputs, attr, params): + a_min = params.pop(inputs[1].name_hint).asnumpy()[0] + a_max = params.pop(inputs[2].name_hint).asnumpy()[0] + return _op.clip(inputs[0], a_min=a_min, a_max=a_max) + return _impl + def _reverse_v2(): def _impl(inputs, attr, params): axis = _get_num_param(params, inputs[1]) @@ -1212,6 +1219,7 @@ def _impl(inputs, attr, params): 'Cast' : _cast(), 'Ceil' : AttrCvt('ceil'), 'CheckNumerics' : _check_numerics(), + 'ClipByValue' : _clip_by_value(), 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), @@ -1245,6 +1253,7 @@ def _impl(inputs, attr, params): 'Mean' : _mean(), 'Minimum' : _elemwise('minimum'), 'Mul' : _elemwise('multiply'), + 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), 'Pack' : _pack(), 'Pad' : _pad('Pad'), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 3899bc04d5c6..498c4735a9e8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -833,6 +833,23 @@ def test_forward_tile(): _test_tile((2, 4, 6), (6, 7, 8), "float64") +####################################################################### +# ClipByValue +# ----------- + +def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype): + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue") + np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0') + +def test_forward_clip_by_value(): + '''test ClipByValue op''' + if tf.__version__ < LooseVersion('1.9'): + _test_forward_clip_by_value((4,), .1, 5., 'float32') + _test_forward_clip_by_value((4, 4), 1, 5, 'int32') + ####################################################################### # Multi Input to graph # -------------------- @@ -1591,6 +1608,14 @@ def test_forward_log(): tf.log(in_data, name="log") compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') +def test_forward_negative(): + """test tf operator Neg """ + np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data") + tf.negative(in_data, name="negative") + compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0') + def test_forward_softplus(): """test operator Softplus""" np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) @@ -1738,6 +1763,7 @@ def test_placeholder(): test_forward_unstack() test_forward_tile() test_forward_top_k_v2() + test_forward_clip_by_value() # Activations test_forward_sigmoid() @@ -1753,6 +1779,7 @@ def test_placeholder(): test_forward_pow_exp() test_forward_sign() test_forward_log() + test_forward_negative() test_forward_softplus() test_forward_sqrt() test_forward_rsqrt() @@ -1802,4 +1829,4 @@ def test_placeholder(): test_where() test_forward_matmul() - # TODO missing tests: rank, range \ No newline at end of file + # TODO missing tests: rank, range From 9ab290f8ca21521211a61a6a8d0ebcfbaf82cbfd Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 13 Jun 2019 09:21:19 +0800 Subject: [PATCH 125/176] Support export ADT value in Python (#3299) * Support export ADT value in Python * Cache original functions * Cleanup * Cleanup --- include/tvm/relay/interpreter.h | 13 ++++-- python/tvm/relay/backend/interpreter.py | 4 +- python/tvm/relay/backend/vm.py | 1 - python/tvm/relay/prelude.py | 1 - python/tvm/relay/testing/nat.py | 12 +++--- src/relay/backend/interpreter.cc | 17 ++++---- src/relay/backend/vm/compiler.cc | 41 ++++++------------- src/relay/backend/vm/vm.cc | 23 ++++------- src/relay/pass/pass_manager.cc | 3 +- tests/python/relay/test_adt.py | 11 +++-- .../python/relay/test_backend_interpreter.py | 12 +++--- .../relay/test_pass_to_a_normal_form.py | 4 +- tests/python/relay/test_vm.py | 8 ++-- 13 files changed, 69 insertions(+), 81 deletions(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 15c96bb12822..68b7ccab99c7 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - Constructor constructor; + int tag; tvm::Array fields; + /*! \brief Optional field tracking ADT constructor. */ + Constructor constructor; + void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("constructor", &constructor); + v->Visit("tag", &tag); v->Visit("fields", &fields); + v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(Constructor constructor, - tvm::Array fields); + TVM_DLL static ConstructorValue make(int tag, + tvm::Array fields, + Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 593cf7cfbdf7..ea25b970f87f 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -73,9 +73,9 @@ class Closure(Value): @register_relay_node class ConstructorValue(Value): - def __init__(self, constructor, fields, types): + def __init__(self, tag, fields, constructor, types): self.__init_handle_by_constructor__( - _make.ConstructorValue, constructor, fields, types) + _make.ConstructorValue, tag, fields, constructor, types) @register_relay_node diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 3b9946a3958d..4cb3d611abd4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args): args: List[tvm.NDArray, np.ndarray] The arguments to evaluate. """ - mod = optimize(mod) args = list(args) assert isinstance(args, list) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index c801e490d4cf..da75b9d00e13 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -491,7 +491,6 @@ def load_prelude(self): def __init__(self, mod): self.mod = mod self.load_prelude() - self.define_list_adt() self.define_list_hd() self.define_list_tl() diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index 4c0c87ce8a9e..a76a340f113d 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -151,16 +151,16 @@ def add_nat_definitions(prelude): # helper functions for working with nats -def count(n): +def count(prelude, n): """Takes a ConstructorValue corresponding to a nat ADT and converts it into a Python integer. This is an example of using an ADT value in Python. """ assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 'z': + if n.tag == prelude.z.tag: return 0 - assert n.constructor.name_hint == 's' - return 1 + count(n.fields[0]) + assert n.tag == prelude.s.tag + return 1 + count(prelude, n.fields[0]) def make_nat_value(prelude, n): @@ -168,8 +168,8 @@ def make_nat_value(prelude, n): constructs a ConstructorValue representing that value as a nat. """ if n == 0: - return ConstructorValue(prelude.z, [], []) - return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], []) + return ConstructorValue(prelude.z.tag, [], None, []) + return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, []) def make_nat_expr(prelude, n): diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d700c2036e21..1cc81d5174a5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(Constructor constructor, - tvm::Array fields) { +ConstructorValue ConstructorValueNode::make(int tag, + tvm::Array fields, + Constructor constructor) { NodePtr n = make_node(); - n->constructor = constructor; + n->tag = tag; n->fields = fields; + n->constructor = constructor; return ConstructorValue(n); } @@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorValueNode* node, tvm::IRPrinter* p) { - p->stream << "ConstructorValueNode(" << node->constructor + p->stream << "ConstructorValueNode(" << node->tag << "," << node->fields << ")"; }); @@ -448,7 +450,7 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueNode::make(GetRef(con), args); + return ConstructorValueNode::make(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); @@ -544,9 +546,8 @@ class Interpreter : const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); - CHECK_NE(cvn->constructor->tag, -1); - if (op->constructor->tag == cvn->constructor->tag) { - // todo(M.K.): should use ptr equality but it is broken + CHECK_NE(cvn->tag, -1); + if (op->constructor->tag == cvn->tag) { CHECK_EQ(op->patterns.size(), cvn->fields.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { if (!VisitPattern(op->patterns[i], cvn->fields[i])) { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 07633fc346ec..9b4ab6b8f6c8 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,6 +80,8 @@ struct VMCompilerContext { ConstTensorShapeMap const_tensor_shape_map; // List of lowered functions std::vector lowered_funcs; + // The functions that have been lowered. + std::unordered_map seen_funcs; }; // Compute the constant pool, i.e a mapping from Constant node to constant index. @@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor { size_t registers_num; CompileEngine engine; - /*! \brief The functions that have been lowered. */ - std::unordered_map seen_funcs; - /*! \brief Global shared meta data */ VMCompilerContext* context; @@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const MatchNode* match_node) { auto match = GetRef(match_node); - LOG(FATAL) << "translation of match nodes to the VM is" + LOG(FATAL) << "translation of match nodes to the VM is " << "currently unsupported" << std::endl; } @@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor { } void VisitExpr_(const GlobalVarNode* gvar) { - LOG(FATAL) << "Global variables should only appear in the call position"; + // TODO(wweic): Support Load GlobalVar into a register + LOG(FATAL) << "Loading GlobalVar into register is not yet supported"; } void VisitExpr_(const IfNode* if_node) { @@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor { // TODO(jroesch): support lowered funcs for multiple targets CHECK_EQ(cfunc->funcs.size(), 1); auto op_index = -1; - if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) { + if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) { op_index = this->context->lowered_funcs.size(); this->context->lowered_funcs.push_back(cfunc->funcs[0]); - seen_funcs[cfunc->funcs[0]] = op_index; + this->context->seen_funcs[cfunc->funcs[0]] = op_index; } else { - op_index = seen_funcs[cfunc->funcs[0]]; + op_index = this->context->seen_funcs[cfunc->funcs[0]]; } Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); @@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor { std::vector args_registers; for (auto arg : call_node->args) { - CHECK(arg.as()) << "found: " << AsText(arg, false) << std::endl << arg; this->VisitExpr(arg); args_registers.push_back(last_register); } @@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor { auto func = this->context->module->Lookup(global); if (IsClosure(func)) { auto arity = func->params.size(); - std::vector free_var_registers; - for (size_t i = 0; i < arity; ++i) { - free_var_registers.push_back(var_register_map.at(func->params[i])); - } - Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister())); + Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); } else { Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); } } else if (auto constructor_node = op.as()) { auto constructor = GetRef(constructor_node); - auto tag = GetConstructorTag(constructor); - Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister())); + Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, + NewRegister())); } else if (auto var_node = op.as()) { VisitExpr(GetRef(var_node)); Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister())); @@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor { } } - size_t GetConstructorTag(tvm::relay::Constructor constructor) { - auto it = this->context->tag_map.find(constructor); - if (it != this->context->tag_map.end()) { - return it->second; - } else { - auto tag = this->context->tag_map.size(); - this->context->tag_map[constructor] = tag; - this->context->tag_index_map[tag] = constructor; - return tag; - } - } - void VisitExpr_(const FunctionNode* func_node) { if (!func_node->IsPrimitive()) { LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl @@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, } VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { - DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl; + DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl; size_t params = func->params.size(); VMCompiler compiler(context); compiler.Compile(func); diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc index 34d067b9c68c..cf0b952005fc 100644 --- a/src/relay/backend/vm/vm.cc +++ b/src/relay/backend/vm/vm.cc @@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector ctxs, return res; } -Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) { - CHECK(module.defined() && type.defined()); +Value VMToValue(const relay::Module& module, Object obj) { + CHECK(module.defined()); switch (obj->tag) { case ObjectTag::kTensor: { - CHECK(type.as()) << "VM internal error: return value must be a tensor"; return TensorValueNode::make(ToNDArray(obj)); } case ObjectTag::kDatatype: { - // const auto* tuple_type - // const auto& data_type = obj.AsDatatype(); + const auto& data_type = obj.AsDatatype(); - // tvm::Array fields; - // for (size_t i = 0; i < data_type->fields.size(); ++i) { - // fields.push_back(VMToValue(tag_index_map, data_type->fields[i])); - // } + tvm::Array fields; + for (size_t i = 0; i < data_type->fields.size(); ++i) { + fields.push_back(VMToValue(module, data_type->fields[i])); + } - // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields); - LOG(FATAL) << "fix me"; + return ConstructorValueNode::make(data_type->tag, fields); } default: LOG(FATAL) << "unsupported return value of type: " << obj->tag; @@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue LOG(FATAL) << "expected function or module"; } - auto return_type = module->Lookup(module->entry_func)->ret_type; - std::vector vm_args; for (auto i = 3; i < args.size(); i++) { Object obj = args[i]; @@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue auto result = EvaluateModule(module, {ctx}, vm_args); DLOG(INFO) << "Evaluate VM returning: result=" << result->tag; - *ret = VMToValue(module, return_type, result); + *ret = VMToValue(module, result); }); } // namespace vm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 500bdce742a0..fa79a5e82f9e 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -316,7 +316,8 @@ Module FunctionPassNode::operator()(const Module& mod, Module updated_mod = mod; // Execute the pass function and return a new module. std::vector > updates; - for (const auto& it : mod->functions) { + auto original = mod->functions; + for (const auto& it : original) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, updated_mod, pass_ctx); diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 77f4ab1f16a0..f3a08a869841 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -21,12 +21,15 @@ from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay import testing, create_executor from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr +from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) +def count(e): + return count_(p, e) + ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") @@ -91,18 +94,18 @@ def to_list(l): val = l ret = [] while True: - if val.constructor.name_hint == 'cons': + if val.tag == p.cons.tag: ret.append(val.fields[0]) val = val.fields[1] else: - assert val.constructor.name_hint == 'nil' + assert val.tag == p.nil.tag break return ret def tree_to_dict(t): assert isinstance(t, ConstructorValue) ret = {} - assert t.constructor.name_hint == 'rose' + assert t.tag == p.rose.tag ret['member'] = t.fields[0] ret['children'] = [] for subtree in to_list(t.fields[1]): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 1e5e2310e927..11ce11e48322 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple(): prelude = relay.prelude.Prelude(mod) intrp = create_executor("debug", mod) - nil_value = ConstructorValue(prelude.nil, [], []) - cons_value = ConstructorValue(prelude.cons, [ + nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, []) + cons_value = ConstructorValue(prelude.cons.tag, [ TensorValue(np.random.rand(1, 10).astype('float32')), nil_value - ], [relay.TensorType((1, 10), 'float32')]) + ], prelude.cons, [relay.TensorType((1, 10), 'float32')]) ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) tuple_value = TupleValue(*[ @@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple(): id_func = intrp.evaluate(prelude.id) res_nil = id_func(nil_value) - assert res_nil.constructor == nil_value.constructor + assert res_nil.tag == nil_value.tag assert len(res_nil.fields) == 0 res_cons = id_func(cons_value) - assert res_cons.constructor == cons_value.constructor + assert res_cons.tag == cons_value.tag assert len(res_cons.fields) == len(cons_value.fields) tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(), cons_value.fields[0].asnumpy()) assert isinstance(res_cons.fields[1], ConstructorValue) - assert res_cons.fields[1].constructor == prelude.nil + assert res_cons.fields[1].tag == prelude.nil.tag assert len(res_cons.fields[1].fields) == 0 res_ref = id_func(ref_value) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index f395580a3f84..db40c86d4b28 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -142,8 +142,8 @@ def test_nat_add(): ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - assert count(intrp.evaluate(add(s(z()), s(z())))) == 2 - assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 + assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 + assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert "let" in mod[add].astext() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d727e776cbcd..12e343be02ac 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -185,9 +185,7 @@ def test_tuple_second(): result = veval(f, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) -@nottest def test_list_constructor(): - # TODO(wweic): implement pattern match to support this test def to_list(o): if isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.data.asnumpy().tolist()] @@ -204,6 +202,11 @@ def to_list(o): cons = p.cons l = p.l + # remove all functions to not have pattern match to pass vm compilation + # TODO(wweic): remove the hack and implement pattern match + for v, _ in mod.functions.items(): + mod[v] = relay.const(0) + one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) one4 = cons(relay.const(3), one3) @@ -213,7 +216,6 @@ def to_list(o): result = veval(mod)() obj = to_list(result) - import pdb; pdb.set_trace() tvm.testing.assert_allclose(obj, np.array([3,2,1])) def test_let_tensor(): From 5bd12cce7338f90910d5db3c87680fd9c36b5f9f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 13 Jun 2019 08:51:58 -0700 Subject: [PATCH 126/176] [Team] Jian Weng -> Committer (#3359) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 5b5c6b745efb..64cda9d0e623 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -55,6 +55,7 @@ We do encourage everyone to work anything they are interested in. - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Leyuan Wang](https://github.com/Laurawly): @Laurawly: - topi - [Yao Wang](https://github.com/kevinthesun): @kevinthesun: - topi, vision +- [Jian Weng](https://github.com/were): @were: - hybrid script - [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, topi, relay From c154e6b320e0f737e6726ea771f4f5e3c30f2b27 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Thu, 13 Jun 2019 08:52:25 -0700 Subject: [PATCH 127/176] Update tflite schema version to 1.13 (#3356) --- docker/install/ubuntu_install_tflite.sh | 2 +- tests/python/frontend/tflite/test_forward.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index d70f9890053d..802fb3b87d8c 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -35,7 +35,7 @@ pip2 install flatbuffers # Setup tflite from schema mkdir tflite cd tflite -wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/contrib/lite/schema/schema.fbs +wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs flatc --python schema.fbs cat <setup.py diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 5c2e3afb5a0d..23d46974b243 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables -from tensorflow.contrib import lite as interpreter_wrapper +from tensorflow import lite as interpreter_wrapper import tvm.relay.testing.tf as tf_testing From aa91e52bcd022df0eeee5bdbb036c671268357fa Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 13 Jun 2019 08:57:43 -0700 Subject: [PATCH 128/176] [Relay][Transform] quantize opt passes to pass manager (#3289) --- python/tvm/relay/quantize/quantize.py | 173 ++++++++++---------------- src/relay/pass/pass_manager.cc | 1 + src/relay/pass/quantize.cc | 63 ++++++---- 3 files changed, 107 insertions(+), 130 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 66c35b66a498..992217cf7d07 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -21,7 +21,9 @@ from . import _quantize from .. import expr as _expr +from .. import module as _module from .. import ir_pass as _ir_pass +from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -178,26 +180,7 @@ def _set_conv_counter(n): CONV_COUNTER = n -def annotate(graph): - """Given a float32 graph, annotate will rewrite the graph - and return back a graph which simulates the error brought by - current quantization scheme. - - Parameters - --------- - graph: Function - The original graph - - Returns - ------- - ret: Function - The graph after annotation - """ - _set_conv_counter(0) # reset counter - return _quantize.annotate(graph) - - -def calibrate(graph, dataset=None): +def calibrate(graph, mod=None, ctx=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. @@ -207,8 +190,11 @@ def calibrate(graph, dataset=None): graph: Function The simulation graph after annotation. - dataset: list of dict of Var -> NDArray - The calibration dataset. + mod: tvm.relay.Module + The module where calibration happens on. + + ctx: tvm.relay.PassContext + The pass context used for calibration. Returns ------- @@ -253,93 +239,52 @@ def _make_const(val): return _expr.bind(graph, const_params) -def realize(graph): - """The realize pass will transform the simulated quantized - graph, which computes with float32 actually, to a real low-bit - integer graph. It will replace the simulated_quantize with - several fine-grained operators like add, multiply, and shift - as more as possible for performance (fusion, etc.) - - Parameters - --------- - graph: Function - The simulated graph after calibrating. +def annotate(): + """Given a float32 graph, this pass will rewrite the graph and return + a graph which simulates the error brought by the current quantization + scheme. Returns ------- - ret: Function - The graph after realization + ret: tvm.relay.Pass + The registered pass for quantization annotation. """ - return _quantize.realize(graph) + return _quantize.QuantizeAnnotate() -def optimize(func, params=None): - """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and - "CanonicalizeOps" optimization before quantization. - - # TODO(zhiics) These passes are executed one by one so far. We need to - # move them to the pass manager. - - Parameters - --------- - func: tvm.relay.Function - The original Relay function to be optimized. - - params : dict of str to tvm.NDArray - Input parameters to the graph that do not change - during inference time. Used for constant folding. +def realize(): + """The realize pass will transform the simulated quantized graph, which + actually computes with float32, to a real low-bit integer graph. It will + replace the `simulated_quantize` with several fine-grained operators like + add, multiply, and shift as much as possible for better performance. Returns ------- - ret: tvm.relay.Function - The graph after quantization + ret: tvm.relay.Pass + The registered pass for quantization realization. """ + return _quantize.QuantizeRealize() - opt_passes = ["SimplifyInference", - "FoldScaleAxis", - "FoldConstant", - "CanonicalizeOps"] - if params: - name_dict = {} - for arg in func.params: - name = arg.name_hint - if name in name_dict: - name_dict[name] = None - else: - name_dict[name] = arg - bind_dict = {} - for k, v in params.items(): - if k not in name_dict: - continue - arg = name_dict[k] - if arg is None: - raise ValueError("Multiple args in the function have name %s" % k) - bind_dict[arg] = _expr.const(v) - func = _expr.bind(func, bind_dict) - - if "SimplifyInference" in opt_passes: - func = _ir_pass.infer_type(func) - func = _ir_pass.simplify_inference(func) - - if "FoldConstant" in opt_passes: - func = _ir_pass.fold_constant(func) - - if "FoldScaleAxis" in opt_passes: - func = _ir_pass.infer_type(func) - func = _ir_pass.backward_fold_scale_axis(func) - func = _ir_pass.infer_type(func) - func = _ir_pass.forward_fold_scale_axis(func) - func = _ir_pass.fold_constant(func) - - if "CanonicalizeOps" in opt_passes: - func = _ir_pass.infer_type(func) - func = _ir_pass.canonicalize_ops(func) - - if "FoldConstant" in opt_passes: - func = _ir_pass.fold_constant(func) - - return func +def _bind_params(func, params): + """Bind the params to the expression. + """ + name_dict = {} + for arg in func.params: + name = arg.name_hint + if name in name_dict: + name_dict[name] = None + else: + name_dict[name] = arg + bind_dict = {} + for k, v in params.items(): + if k not in name_dict: + continue + arg = name_dict[k] + if arg is None: + raise ValueError("Multiple args in the function have name %s" % k) + bind_dict[arg] = _expr.const(v) + return _expr.bind(func, bind_dict) def quantize(graph, params=None, dataset=None): @@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - # TODO(zhiics) Move this to the pass manager. - graph = optimize(graph, params) - - graph = annotate(graph) - graph = calibrate(graph, dataset) - graph = realize(graph) - graph = _ir_pass.fold_constant(graph) - return graph + if params: + graph = _bind_params(graph, params) + + mod = _module.Module.from_expr(graph) + # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and + # "CanonicalizeOps" optimization before quantization. + optimize = _transform.Sequential([_transform.SimplifyInference(), + _transform.FoldConstant(), + _transform.FoldScaleAxis(), + _transform.CanonicalizeOps(), + _transform.FoldConstant()]) + + calibrate_pass = _transform.function_pass(calibrate, opt_level=1, + name="QuantizeCalibrate") + _set_conv_counter(0) # reset counter + quantize_seq = _transform.Sequential([annotate(), + calibrate_pass, + realize(), + _transform.FoldConstant()]) + with _transform.PassContext(opt_level=3, + required_pass=["QuantizeAnnotate", + "QuantizeCalibrate", + "QuantizeRealize"]): + mod = optimize(mod) + mod = quantize_seq(mod) + return mod[mod.entry_func.name_hint] diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index fa79a5e82f9e..d63d9121fe27 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod, << pass_info->name << " with opt level: " << pass_info->opt_level; + Module updated_mod = mod; // Execute the pass function and return a new module. std::vector > updates; diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 3a2e54c8ad39..7b6c1ffc3ed0 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -43,6 +43,8 @@ namespace tvm { namespace relay { namespace quantize { +using namespace relay::transform; + /*! \brief Attribute for simulated quantize operator */ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { int kind; @@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") static_cast(args[1].operator int())); }); - -TVM_REGISTER_API("relay._quantize.annotate") -.set_body_typed([] (const Expr& expr) { - std::function fmulti_ref = [](const Expr& e) { - if (e->derived_from()) { - const auto* n = e.as(); - CHECK(n); - const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); - } - return e; - }; - return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref); -}); - - // ============= // realize pass @@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call, RELAY_REGISTER_OP("nn.avg_pool2d") .set_attr("FQRealizeRewrite", AvgPoolRealize); - -TVM_REGISTER_API("relay._quantize.realize") -.set_body_typed([](const Expr& e) { - Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr); - return ret; -}); - - // ============= // qconfig @@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope") TVM_REGISTER_API("relay._quantize._ExitQConfigScope") .set_body_typed(QConfig::ExitQConfigScope); +Pass QuantizeAnnotate() { + std::function fmulti_ref = [](const Expr& e) { + if (e->derived_from()) { + const auto* n = e.as(); + CHECK(n); + const PackedFunc* f = + runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + Expr ret = (*f)(n->expr, static_cast(kQInput)); + return static_cast(QAnnotateExprNode::make(ret, kQInput)); + } + return e; + }; + + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref)); + }; + return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); +} + +TVM_REGISTER_API("relay._quantize.QuantizeAnnotate") +.set_body_typed(QuantizeAnnotate); + +Pass QuantizeRealizePass() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); + }; + return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); +} + +TVM_REGISTER_API("relay._quantize.QuantizeRealize") +.set_body_typed(QuantizeRealizePass); + } // namespace quantize } // namespace relay } // namespace tvm From 5701d2e78dff96f3a39cd390a1ce0c54122aa38a Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 09:02:26 -0700 Subject: [PATCH 129/176] [Relay] Check match expressions for completeness (#3203) --- include/tvm/relay/pass.h | 31 +- python/tvm/relay/ir_pass.py | 18 ++ python/tvm/relay/prelude.py | 2 - src/relay/pass/match_exhaustion.cc | 250 ++++++++++++++++ src/relay/pass/type_infer.cc | 9 + .../python/relay/test_pass_unmatched_cases.py | 267 ++++++++++++++++++ 6 files changed, 574 insertions(+), 3 deletions(-) create mode 100644 src/relay/pass/match_exhaustion.cc create mode 100644 tests/python/relay/test_pass_unmatched_cases.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 81587339f2ad..977bb6793bb5 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -122,6 +122,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); */ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); +/*! + * \brief Compare two patterns for structural equivalence. + * + * This comparison operator respects scoping and compares + * patterns without regard to variable choice. + * + * For example: `A(x, _, y)` is equal to `A(z, _, a)`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param t1 The left hand pattern. + * \param t2 The right hand pattern. + * + * \return true if equal, otherwise false + */ +TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); + /*! * \brief Add abstraction over a function * @@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); TVM_DLL Expr ToGraphNormalForm(const Expr& e); /*! - * \brief Aggressive constant propagation/constant folding/inlining. + * \brief Finds cases that the given match expression does not catch, if any. + * + * \param match the match expression to test + * + * \param mod The module used for accessing global type var definitions, can be None. * + * \return Returns a list of cases (as patterns) that are not handled by the match + * expression. + */ +TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); + +/*! + * \brief Aggressive constant propagation/constant folding/inlining. * It will do as much computation in compile time as possible. * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ea34c6b1958b..8f1ceded76dd 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -652,3 +652,21 @@ def partial_evaluate(expr): The output expression. """ return _ir_pass.partial_evaluate(expr) + +def unmatched_cases(match, mod=None): + """ + Finds cases that the match expression does not catch, if any. + + Parameters + ---------- + match : tvm.relay.Match + The match expression + mod : Optional[tvm.relay.Module] + The module (defaults to an empty module) + + Returns + ------- + missing_patterns : [tvm.relay.Pattern] + Patterns that the match expression does not catch. + """ + return _ir_pass.unmatched_cases(match, mod) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index da75b9d00e13..17df61750afd 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -39,7 +39,6 @@ def define_list_adt(self): self.cons = Constructor("cons", [a, self.l(a)], self.l) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) - def define_list_hd(self): """Defines a function to get the head of a list. Assume the list has at least one element. @@ -54,7 +53,6 @@ def define_list_hd(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) - def define_list_tl(self): """Defines a function to get the tail of a list. diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc new file mode 100644 index 000000000000..173d6eacf528 --- /dev/null +++ b/src/relay/pass/match_exhaustion.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file match_exhaustion.cc + * \brief Checking Relay match expression exhaustiveness. + * + * This file implements a function that checks whether a match + * expression is exhaustive, that is, whether a given match clause + * matches every possible case. This is important for ensuring + * code correctness, since hitting an unmatched case results in a + * dynamic error unless exhaustiveness is checked in advance. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Possible pattern match results */ +enum MatchResult : int { + kMatch = 0, // pattern matches + kClash = 1, // pattern conflicts + kUnspecified = 2, // ambiguous: candidate needs more constructors specified +}; + +class CandidateChecker : public PatternFunctor { + public: + explicit CandidateChecker() {} + + MatchResult Check(const Pattern& pat, const Pattern& candidate) { + return this->VisitPattern(pat, candidate); + } + + // for a constructor pattern, we must ensure that the candidate is + // a ConstructorPattern, that it has the same constructor, and + // that its fields match the subpatterns. + MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override { + auto* ctor_cand = cand.as(); + // attempting to match non-constructor to constructor pattern: need to specify + if (ctor_cand == nullptr) { + return MatchResult::kUnspecified; + } + + // check that constructors match + if (!op->constructor.same_as(ctor_cand->constructor)) { + return MatchResult::kClash; + } + + // now check that subpatterns match + CHECK(op->patterns.size() == ctor_cand->patterns.size()); + bool unspecified = false; + for (size_t i = 0; i < op->patterns.size(); i++) { + MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); + // if we have a clash anywhere, then we can return clash + if (submatch == MatchResult::kClash) { + return MatchResult::kClash; + } + if (submatch == MatchResult::kUnspecified) { + unspecified = true; + } + } + // only return unspecified if we have ruled out a clash + if (unspecified) { + return MatchResult::kUnspecified; + } + return MatchResult::kMatch; + } + + // wildcard and var patterns always match + MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { + return MatchResult::kMatch; + } + + MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override { + return MatchResult::kMatch; + } +}; + +// Returns list of arrays corresponding to Cartesian product of input list +Array> CartesianProduct(Array> fields) { + CHECK_NE(fields.size(), 0); + Array field_vals = fields[fields.size() - 1]; + Array> ret; + + // base case: this is the last field left + if (fields.size() == 1) { + for (auto val : field_vals) { + ret.push_back(Array{val}); + } + return ret; + } + + // if we have more fields left, get the sub-candidates by getting + // their cartesian product and appending the elements here onto those + Array> remaining_fields; + for (size_t i = 0; i < fields.size() - 1; i++) { + remaining_fields.push_back(fields[i]); + } + Array> candidates = CartesianProduct(remaining_fields); + for (auto val : field_vals) { + for (auto candidate : candidates) { + candidate.push_back(val); + ret.push_back(candidate); + } + } + return ret; +} + +// Expands all wildcards in the candidate pattern once, using the pattern +// to decide which constructors to insert. Returns a list of all possible expansions. +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, + const Module& mod) { + auto ctor_cand = cand.as(); + PatternConstructor clause_ctor = Downcast(clause_pat); + auto gtv = Downcast(clause_ctor->constructor->belong_to); + + // for a wildcard node, create constructor nodes with wildcards for all args + if (!ctor_cand) { + TypeData td = mod->LookupDef(gtv); + // for each constructor add a candidate + Array ret; + for (auto constructor : td->constructors) { + Array args; + for (auto inp : constructor->inputs) { + args.push_back(PatternWildcardNode::make()); + } + ret.push_back(PatternConstructorNode::make(constructor, args)); + } + return ret; + } + + // for constructors, we will expand the wildcards in any field + // that is an ADT + Array> values_by_field; + for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { + auto* subpattern = clause_ctor->patterns[i].as(); + // for non-ADT fields, we can only have a wildcard for the value + if (!subpattern) { + values_by_field.push_back({PatternWildcardNode::make()}); + continue; + } + + // otherwise, recursively expand + values_by_field.push_back(ExpandWildcards(GetRef(subpattern), + ctor_cand->patterns[i], mod)); + } + + // generate new candidates using a cartesian product + auto all_subfields = CartesianProduct(values_by_field); + Array ret; + for (auto subfields : all_subfields) { + ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + } + return ret; +} + +/*! + * \brief Finds cases that the match expression does not catch, if any. + * \return Returns a list of cases that are not handled by the match + * expression. + */ +Array UnmatchedCases(const Match& match, const Module& mod) { + /* algorithm: + * candidates = { Wildcard } + * while candidates not empty { + * cand = candidates.pop() + * for clause in clauses { + * if clause fails: next clause + * if clause matches candidate: next candidate + * if candidate is not specific enough: + * candidates += expand_possible_wildcards(cand) + * next candidate + * } + * failed_candidates += { cand } + * } + * return failed_candidates + */ + std::stack candidates; + candidates.push(PatternWildcardNode::make()); + CandidateChecker checker; + + Array failures; + + while (!candidates.empty()) { + Pattern cand = candidates.top(); + candidates.pop(); + + bool failure = true; + for (auto clause : match->clauses) { + // if the check fails, we move on to the next + MatchResult check = checker.Check(clause->lhs, cand); + if (check == MatchResult::kClash) { + continue; + } + + // either success or we need to generate more candidates; + // either way, we're done with this candidate + failure = false; + if (check == MatchResult::kUnspecified) { + auto new_candidates = ExpandWildcards(clause->lhs, cand, mod); + for (auto candidate : new_candidates) { + candidates.push(candidate); + } + } + break; + } + + if (failure) { + failures.push_back(cand); + } + } + + return failures; +} + +// expose for testing only +TVM_REGISTER_API("relay._ir_pass.unmatched_cases") +.set_body_typed(const Match&, + const Module&)>([](const Match& match, + const Module& mod_ref) { + Module call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = ModuleNode::make({}, {}); + } + return UnmatchedCases(match, call_mod); + }); +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3fde3c7e7b36..4b126e5299cf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor, GetType(c->rhs), op->span); } + + // check completness + Match match = GetRef(op); + Array unmatched_cases = UnmatchedCases(match, this->mod_); + if (unmatched_cases.size() != 0) { + LOG(WARNING) << "Match clause " << match << " does not handle the following cases: " + << unmatched_cases; + } + return rtype; } diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py new file mode 100644 index 000000000000..4f2bb20ad7d6 --- /dev/null +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -0,0 +1,267 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relay +from tvm.relay.prelude import Prelude +from tvm.relay.ir_pass import unmatched_cases + +def test_empty_match_block(): + # empty match block will not match anything, so it should return a wildcard pattern + v = relay.Var('v') + match = relay.Match(v, []) + + unmatched = unmatched_cases(match) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternWildcard) + + +def test_trivial_matches(): + # a match clause with a wildcard will match anything + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause(relay.PatternWildcard(), v) + ]) + assert len(unmatched_cases(match)) == 0 + + # same with a pattern var + w = relay.Var('w') + match = relay.Match(v, [ + relay.Clause(relay.PatternVar(w), w) + ]) + assert len(unmatched_cases(match)) == 0 + + +def test_single_constructor_adt(): + mod = relay.Module() + box = relay.GlobalTypeVar('box') + a = relay.TypeVar('a') + box_ctor = relay.Constructor('box', [a], box) + box_data = relay.TypeData(box, [a], [box_ctor]) + mod[box] = box_data + + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), v) + ]) + + # with one constructor, having one pattern constructor case is exhaustive + assert len(unmatched_cases(match, mod)) == 0 + + # this will be so if we nest the constructors too + nested_pattern = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, + [relay.PatternConstructor(box_ctor, + [relay.PatternConstructor( + box_ctor, + [relay.PatternWildcard()])])]), v) + ]) + assert len(unmatched_cases(nested_pattern, mod)) == 0 + + +def test_too_specific_match(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v) + ]) + + unmatched = unmatched_cases(match, mod) + + # will not match nil or a list of length 1 + nil_found = False + single_length_found = False + assert len(unmatched) == 2 + for case in unmatched: + assert isinstance(case, relay.PatternConstructor) + if case.constructor == p.nil: + nil_found = True + if case.constructor == p.cons: + assert isinstance(case.patterns[1], relay.PatternConstructor) + assert case.patterns[1].constructor == p.nil + single_length_found = True + assert nil_found and single_length_found + + # if we add a wildcard, this should work + new_match = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v), + relay.Clause(relay.PatternWildcard(), v) + ]) + assert len(unmatched_cases(new_match, mod)) == 0 + + +def test_multiple_constructor_clauses(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + # list of length exactly 1 + relay.Clause( + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, [])]), v), + # list of length exactly 2 + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, []) + ])]), v), + # empty list + relay.Clause( + relay.PatternConstructor(p.nil, []), v), + # list of length 2 or more + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v) + ]) + assert len(unmatched_cases(match, mod)) == 0 + + +def test_missing_in_the_middle(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + # list of length exactly 1 + relay.Clause( + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, [])]), v), + # empty list + relay.Clause( + relay.PatternConstructor(p.nil, []), v), + # list of length 3 or more + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternWildcard()])])]), + v) + ]) + + # fails to match a list of length exactly two + unmatched = unmatched_cases(match, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == p.cons + assert isinstance(unmatched[0].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].constructor == p.cons + assert isinstance(unmatched[0].patterns[1].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].patterns[1].constructor == p.nil + + +def test_mixed_adt_constructors(): + mod = relay.Module() + box = relay.GlobalTypeVar('box') + a = relay.TypeVar('a') + box_ctor = relay.Constructor('box', [a], box) + box_data = relay.TypeData(box, [a], [box_ctor]) + mod[box] = box_data + + p = Prelude(mod) + + v = relay.Var('v') + box_of_lists_inc = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, + [relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), relay.PatternWildcard()])]), v) + ]) + + # will fail to match a box containing an empty list + unmatched = unmatched_cases(box_of_lists_inc, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == box_ctor + assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil + + box_of_lists_comp = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, [relay.PatternConstructor(p.nil, [])]), v), + relay.Clause( + relay.PatternConstructor( + box_ctor, [relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), relay.PatternWildcard()])]), v) + ]) + assert len(unmatched_cases(box_of_lists_comp, mod)) == 0 + + list_of_boxes_inc = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternWildcard()]), v) + ]) + + # fails to match empty list of boxes + unmatched = unmatched_cases(list_of_boxes_inc, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == p.nil + + list_of_boxes_comp = relay.Match(v, [ + # exactly one box + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, [])]), v), + # exactly two boxes + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, []) + ])]), v), + # exactly three boxes + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, []) + ])])]), v), + # one or more boxes + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()]), v), + # no boxes + relay.Clause(relay.PatternConstructor(p.nil, []), v) + ]) + assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0 From ddbe01406353e79a7bc2851acc1e55939b30bcfe Mon Sep 17 00:00:00 2001 From: Hua Date: Thu, 13 Jun 2019 11:10:07 -0700 Subject: [PATCH 130/176] [Relay] Add Elemwise operator Sub, Divide, Power, Max, Min to tflite frontend. (#3357) --- python/tvm/relay/frontend/tflite.py | 22 ++++++ tests/python/frontend/tflite/test_forward.py | 70 ++++++++++++++------ 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index eb9e742ff85f..3b27428537e9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -64,7 +64,12 @@ def __init__(self, model, subgraph, exp_tab): 'MAX_POOL_2D': self.convert_max_pool2d, 'CONCATENATION': self.convert_concatenation, 'ADD': self.convert_add, + 'SUB': self.convert_sub, 'MUL': self.convert_mul, + 'DIV': self.convert_div, + 'POW': self.convert_pow, + 'MAXIMUM': self.convert_maximum, + 'MINIMUM': self.convert_minimum, 'FULLY_CONNECTED': self.convert_fully_connected, 'PAD': self.convert_pad, 'LOGISTIC': self.convert_logistic, @@ -320,10 +325,27 @@ def convert_add(self, op): """Convert TFLite ADD""" return self._convert_elemwise(_op.add, op) + def convert_sub(self, op): + """Convert TFLite SUB""" + return self._convert_elemwise(_op.subtract, op) + def convert_mul(self, op): """Convert TFLite MUL""" return self._convert_elemwise(_op.multiply, op) + def convert_div(self, op): + """Convert TFLite DIV""" + return self._convert_elemwise(_op.divide, op) + + def convert_pow(self, op): + return self._convert_elemwise(_op.power, op) + + def convert_maximum(self, op): + return self._convert_elemwise(_op.maximum, op) + + def convert_minimum(self, op): + return self._convert_elemwise(_op.minimum, op) + def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 23d46974b243..549855f0cfb5 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -320,7 +320,7 @@ def _test_elemwise(math_op, data): with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] - out = math_ops.add(in_data[0], in_data[1]) + out = math_op(in_data[0], in_data[1]) compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) # Test with tensor and constant @@ -338,35 +338,66 @@ def _test_add(data): """ One iteration of add """ return _test_elemwise(math_ops.add, data) +####################################################################### +# Subtract +# --- -def test_forward_add(): - """ Add """ - _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), - np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) - _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), - np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) - _test_add([np.arange(3.0, dtype=np.float32).reshape((1, 3)), - np.arange(3.0, dtype=np.float32).reshape((1, 3))]) - - +def _test_sub(data): + """ One iteration of subtract """ + return _test_elemwise(math_ops.subtract, data) ####################################################################### # Mul # --- - def _test_mul(data): """ One iteration of mul """ return _test_elemwise(math_ops.multiply, data) +####################################################################### +# Divide +# --- + +def _test_div(data): + """ One iteration of divide """ + return _test_elemwise(math_ops.divide, data) +####################################################################### +# Power +# --- + +def _test_pow(data): + """ One iteration of power """ + return _test_elemwise(math_ops.pow, data) +####################################################################### +# Maximum +# --- + +def _test_maximum(data): + """ One iteration of maximum """ + return _test_elemwise(math_ops.maximum, data) +####################################################################### +# Minimum +# --- + +def _test_minimum(data): + """ One iteration of minimum """ + return _test_elemwise(math_ops.minimum, data) -def test_forward_mul(): - """ Mul """ - _test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), +def _test_forward_elemwise(testop): + """ Elewise""" + testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) - _test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), + testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) - _test_mul([np.arange(3.0, dtype=np.float32).reshape((1, 3)), + testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)), np.arange(3.0, dtype=np.float32).reshape((1, 3))]) +def test_all_elemwise(): + _test_forward_elemwise(_test_add) + _test_forward_elemwise(_test_sub) + _test_forward_elemwise(_test_mul) + _test_forward_elemwise(_test_div) + _test_forward_elemwise(_test_pow) + _test_forward_elemwise(_test_maximum) + _test_forward_elemwise(_test_minimum) ####################################################################### # Squeeze @@ -584,9 +615,8 @@ def test_forward_inception_v4_net(): test_forward_softmax() test_forward_fully_connected() - # Math - test_forward_add() - test_forward_mul() + # Elemwise + test_all_elemwise() # End to End test_forward_mobilenet_v1() From 08f9f018bf01cf5b2e1eadc49438732f391430c2 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 13 Jun 2019 13:08:48 -0700 Subject: [PATCH 131/176] [Relay][Frontend] Add a bunch of ops in tf converter (#3270) --- python/tvm/relay/frontend/tensorflow.py | 53 +++- .../frontend/tensorflow/test_forward.py | 274 +++++++++++++++++- 2 files changed, 305 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ba076cc2819f..7319d5eb4a7e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -777,12 +777,12 @@ def _impl(inputs, attr, params): ignores=['name', 'Tidx'])([inputs[0]], attr) return _impl -def _reduce_all(): +def _reduce(op): def _impl(inputs, attr, params): axis = params.pop(inputs[1].name_hint).asnumpy() axis = tuple(axis) return AttrCvt( - op_name='all', + op_name=op, extras={'axis': axis}, transforms={'keep_dims':'keepdims'}, ignores=['name', 'Tidx'])([inputs[0]], attr) @@ -807,6 +807,14 @@ def _impl(inputs, attr, params): 'Taxis', '_class'])(new_input, attr) return _impl +def _gather_nd(): + """GatherNd""" + def _impl(inputs, attr, params): + return AttrCvt(op_name="gather_nd", + ignores=['Tindices', 'Tparams',\ + 'Taxis', '_class'])(inputs, attr) + return _impl + def _stridedSlice(): def _impl(inputs, attr, params): """Strided Slice. @@ -971,15 +979,18 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): - start = _get_num_param(params, inputs[0]) - limit = _get_num_param(params, inputs[1]) - delta = _get_num_param(params, inputs[2]) - - name = attr["_node_name"] - params[name] = tvm.nd.array([start, limit, delta]) - return [_expr.var(name, - shape=params[name].shape, - dtype='int32')] + start = params.pop(inputs[0].name_hint).asnumpy()[0] + limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ + if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] + delta = params.pop(inputs[2].name_hint).asnumpy()[0] + dtype = attr['dtype'].name if 'dtype' in attr else "int32" + return AttrCvt( + op_name="arange", + ignores=['Tidx'], + extras={'start': start, + "stop": limit, + 'step': delta, + 'dtype': dtype})([], attr) return _impl def _elu(): @@ -1099,6 +1110,13 @@ def _impl(inputs, attr, params): extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) return _impl +def _floordiv(): + def _impl(inputs, attr, params): + assert len(inputs) == 2 + div = AttrCvt('divide')(inputs, attr) + return _get_relay_op('floor')(div) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -1207,8 +1225,9 @@ def _impl(inputs, attr, params): # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) _convert_map = { + 'Abs' : AttrCvt('abs'), 'Add' : _elemwise('add'), - 'All' : _reduce_all(), + 'All' : _reduce('all'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), @@ -1232,26 +1251,33 @@ def _impl(inputs, attr, params): 'ExpandDims' : _expand_dims(), 'Fill' : _fill(), 'Floor' : AttrCvt('floor'), + 'FloorDiv' : _floordiv(), 'FusedBatchNorm' : _fused_batch_norm(), 'FusedBatchNormV2' : _fused_batch_norm(), 'Gather' : _gather(), + 'GatherNd' : _gather_nd(), 'GatherV2' : _gather(), 'Greater' : _broadcast('greater'), 'GreaterEqual' : _broadcast('greater_equal'), 'Identity' : _identity(), 'LeakyRelu' : AttrCvt('leaky_relu'), + 'LeftShift' : AttrCvt('left_shift'), 'Less' : _broadcast('less'), 'LessEqual' : _broadcast('less_equal'), 'Log' : AttrCvt('log'), 'LogicalAnd' : _logical('logical_and'), 'LogicalOr' : _logical('logical_or'), 'LogicalNot' : _logical('logical_not'), + 'LogSoftmax' : AttrCvt('log_softmax'), 'LRN' : _lrn(), 'MatMul' : _matmul(), + 'Max' : _reduce('max'), 'MaxPool' : _pooling('max_pool'), 'Maximum' : _elemwise('maximum'), 'Mean' : _mean(), + 'Min' : _reduce('min'), 'Minimum' : _elemwise('minimum'), + 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), @@ -1269,6 +1295,7 @@ def _impl(inputs, attr, params): 'ResizeBilinear' : _resize_bilinear(), 'ResizeBicubic' : _resize_bilinear(), 'ReverseV2' : _reverse_v2(), + 'RightShift' : AttrCvt('right_shift'), 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), 'Select' : _where(), @@ -1292,7 +1319,9 @@ def _impl(inputs, attr, params): 'Tile' : _tile(), 'TopKV2' : _topk(), 'Transpose' : _transpose(), + 'TruncateMod' : _elemwise('mod'), 'Unpack' : _unpack(), + 'ZerosLike' : AttrCvt('zeros_like'), } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 498c4735a9e8..6fc825a8924c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -64,6 +64,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, layout=layout, shape=shape_dict, outputs=out_names) + with relay.build_config(opt_level=opt_level): graph, lib, params = relay.build(sym, target, target_host, params) @@ -642,10 +643,53 @@ def test_forward_stridedslice(): 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=8) +####################################################################### +# FloorDiv, RealDiv +# ----------------- + +def _test_forward_divide(ip_shape, dtype): + np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype) + tf.reset_default_graph() + numerator = tf.placeholder(dtype, ip_shape, name="numer") + denominator = tf.placeholder(dtype, ip_shape, name="denomin") + tf.math.divide(numerator, denominator, name='RealDiv') + compare_tf_with_tvm([np_numer, np_denomin], ['numer:0', 'denomin:0'], 'RealDiv:0') + +def _test_forward_floordiv(ip_shape, dtype): + np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + tf.reset_default_graph() + numerator = tf.placeholder(dtype, ip_shape, name="numer") + tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv') + compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0') + +def test_forward_divide(): + '''test FloorDiv, RealDiv''' + _test_forward_divide((4,), 'int32') + _test_forward_divide((4, 3, 7), 'float32') + _test_forward_floordiv((4, 3, 7), 'float32') + ####################################################################### -# Gather, GatherV2 -# ---------------- +# TruncateMod +# ----------- +def _test_forward_truncatemod(ip_shape, dtype): + np_data_1 = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + np_data_2 = np.random.uniform(1, 10, size=ip_shape).astype(dtype) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2") + tf.truncatemod(in_data_1, in_data_2, name='truncatemod') + compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'truncatemod:0') + +def test_forward_truncatemod(): + '''test TruncateMod''' + _test_forward_truncatemod((4, 3, 7), 'int32') + + +####################################################################### +# Gather, GatherV2, GatherNd +# -------------------------- def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): """ One iteration of a GatherV2 """ @@ -718,6 +762,33 @@ def test_forward_gather_v1(): _test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32') +def test_forward_gather_nd(): + """test operator GatherNd""" + np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (2, 2), name="in_data") + tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd") + compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0') + + +####################################################################### +# BiasAdd +# ------- +def test_forward_bias_add(): + """test Op BiasAdd""" + def check_bias_add(lh_shpae, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shpae).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'BiasAdd:0') + + check_bias_add((10, 8, 16, 32), (32,), dtype="int32") + check_bias_add((10, 20), (20,), dtype="float32") + + ####################################################################### # Split # ----- @@ -1109,6 +1180,32 @@ def test_forward_pack(): _test_pack(axis, [3]) _test_pack(0, []) + +####################################################################### +# Unpack +# ------ +def _test_forward_unpack(in_shape, axis, dtype): + """test operator Unpack""" + np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.unstack(in_data, axis=axis, name="Unpack") + compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0') + +def test_forward_unpack(): + _test_forward_unpack((3,), 0, 'int32') + _test_forward_unpack((3,), -1, 'int16') + _test_forward_unpack((21, 23, 3), 2, 'float32') + +####################################################################### +# Range +# ----- +def test_forward_range(): + """test operator Range""" + tf.reset_default_graph() + tf.range(1, 18, 3, name="range") + compare_tf_with_tvm([], [], 'range:0') + ####################################################################### # Pad # --- @@ -1182,7 +1279,7 @@ def test_forward_logical(): ####################################################################### # Where, Select # ------------- -def test_where(): +def test_forward_where(): ''' Where: return elements depending on conditions''' with tf.Graph().as_default(): with tf.Session() as sess: @@ -1553,6 +1650,22 @@ def test_forward_tanh(): tf.nn.tanh(in1) compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0') + +####################################################################### +# Softmax +# ------- +def test_forward_softmax(): + """test operator Softmax """ + def check_softmax(in_shape, axis, dtype): + np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.nn.softmax(in_data, axis=axis, name="Softmax") + compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0') + check_softmax((2, 3, 5), 2, "float32") + check_softmax((2, 3, 5), -1, "float32") + + ####################################################################### # Tensor # ------ @@ -1565,6 +1678,29 @@ def test_forward_round(): tf.round(in_data, name="round") compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0') +def test_forward_abs(): + """test operator Abs""" + np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") + tf.math.abs(in_data, name="abs") + compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0') + +def _test_forward_zeros_like(in_shape, dtype): + np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.zeros_like(in_data, name="zeros_like") + compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0') + +def test_forward_zeros_like(): + if tf.__version__ < LooseVersion('1.2'): + _test_forward_zeros_like((2, 3), "int32") + _test_forward_zeros_like((2, 3, 5), "int8") + _test_forward_zeros_like((2, 3, 5, 7), "uint16") + _test_forward_zeros_like((2, 3, 11), "float32") + _test_forward_zeros_like((2, 3, 11), "float64") + def _test_forward_reverse_v2(in_shape, axis, dtype): np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) tf.reset_default_graph() @@ -1588,6 +1724,14 @@ def test_forward_sign(): tf.sign(in_data, name="sign") compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0') +def test_forward_square(): + """test operator Square """ + np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") + tf.square(in_data, name="square") + compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0') + def test_forward_pow_exp(): """test Pow and Exp """ np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32) @@ -1616,6 +1760,14 @@ def test_forward_negative(): tf.negative(in_data, name="negative") compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0') +def test_forward_log_softmax(): + """test operator LogSoftmax""" + np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") + tf.math.log_softmax(in_data, name="LogSoftmax") + compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0') + def test_forward_softplus(): """test operator Softplus""" np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) @@ -1640,6 +1792,34 @@ def test_forward_sqrt(): tf.sqrt(in_data, name="sqrt") compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0') +def _test_forward_right_shift(in_shape, dtype): + """test operator RightShift""" + lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype) + rh_data = np.random.randint(1, 8, size=in_shape).astype(dtype) + tf.reset_default_graph() + lft_data = tf.placeholder(dtype, in_shape, name="lft_data") + rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") + tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'RightShift:0') + +def test_forward_right_shift(): + _test_forward_right_shift((7,), 'int32') + _test_forward_right_shift((3, 11), 'int16') + +def _test_forward_left_shift(in_shape, dtype): + """test operator LeftShift""" + lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype) + rh_data = np.random.randint(1, 3, size=in_shape).astype(dtype) + tf.reset_default_graph() + lft_data = tf.placeholder(dtype, in_shape, name="lft_data") + rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") + tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'LeftShift:0') + +def test_forward_left_shift(): + _test_forward_left_shift((10,), 'int32') + _test_forward_left_shift((224, 224, 3), 'int16') + ####################################################################### # Mean # ---- @@ -1652,13 +1832,13 @@ def check_mean(ishape, **kwargs): compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True) check_mean((10, 8, 16, 32)) - check_mean((10, 8, 16, 32), axis=(2,3)) - check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True) + check_mean((10, 8, 16, 32), axis=(2, 3)) + check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True) ####################################################################### -# All -# --- -def test_forward_all(): +# All, Max, Min +# ------------- +def test_forward_reduce_all(): """Test the All operator.""" np_data = np.random.choice([True, False], size=(5, 7, 11)) tf.reset_default_graph() @@ -1666,6 +1846,30 @@ def test_forward_all(): tf.reduce_all(in_data, name="all") compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0') +def test_forward_reduce_max(): + def check_max(ishape, axis, keepdims, dtype): + tf.reset_default_graph() + np_data = np.random.uniform(size=ishape).astype(dtype) + in_data = tf.placeholder(dtype, name="in_data") + tf.math.reduce_max(in_data, axis=axis, keepdims=keepdims, name="reduce_max") + compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') + + check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") + check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32") + check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32') + +def test_forward_reduce_min(): + def check_min(ishape, axis, keepdims, dtype): + tf.reset_default_graph() + np_data = np.random.uniform(size=ishape).astype(dtype) + in_data = tf.placeholder(dtype, name="in_data") + tf.math.reduce_min(in_data, axis=axis, keepdims=keepdims, name="reduce_max") + compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') + + check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") + check_min((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32") + check_min((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32') + ####################################################################### # Relational operators # -------------------- @@ -1723,6 +1927,38 @@ def test_forward_reduce_prod(): _test_forward_reduce_prod((5, 5), 1, True) +####################################################################### +# Maximum, Minimum +# ---------------- +def test_forward_maximum(): + """test Op Maximum""" + def check_maximum(lh_shape, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shape).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.math.maximum(lft_data, rgt_data, name="maximum") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'maximum:0') + + check_maximum((10, 8, 16, 32), (1,), dtype="int32") + check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") + +def test_forward_minimum(): + """test Op Minimum""" + def check_minimum(lh_shape, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shape).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.math.minimum(lft_data, rgt_data, name="minimum") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'minimum:0') + + check_minimum((10, 8, 16, 32), (1,), dtype="int32") + check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") + + ####################################################################### # PlaceholderWithDefault # ---------------------- @@ -1740,6 +1976,7 @@ def test_placeholder(): compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) + ####################################################################### # Main # ---- @@ -1756,14 +1993,22 @@ def test_placeholder(): test_forward_fill() test_forward_crop() test_forward_pad() + test_forward_unpack() test_forward_gather() test_forward_gather_v1() + test_forward_gather_nd() test_forward_stridedslice() test_forward_split() test_forward_unstack() test_forward_tile() test_forward_top_k_v2() test_forward_clip_by_value() + test_forward_maximum() + test_forward_minimum() + test_forward_range() + test_forward_right_shift() + test_forward_left_shift() + test_forward_truncatemod() # Activations test_forward_sigmoid() @@ -1780,17 +2025,26 @@ def test_placeholder(): test_forward_sign() test_forward_log() test_forward_negative() + test_forward_divide() + test_forward_abs() test_forward_softplus() test_forward_sqrt() test_forward_rsqrt() test_forward_expand_dims() + test_forward_square() + test_forward_softmax() + test_forward_log_softmax() + test_forward_bias_add() + test_forward_zeros_like() # Reductions test_forward_argminmax() test_forward_reduce() test_forward_mean() test_forward_reduce_prod() - test_forward_all() + test_forward_reduce_all() + test_forward_reduce_max() + test_forward_reduce_min() # General test_forward_multi_input() @@ -1826,7 +2080,7 @@ def test_placeholder(): # Relational ops test_forward_rel_ops() test_forward_logical() - test_where() + test_forward_where() test_forward_matmul() # TODO missing tests: rank, range From 827a8fb1c8c8ad751e09a9679709b8c935f6661e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 13 Jun 2019 13:09:58 -0700 Subject: [PATCH 132/176] [ARITH] Revamp IntSet (#3272) --- include/tvm/arithmetic.h | 194 ++-- python/tvm/arith.py | 43 +- src/api/api_arith.cc | 5 + src/arithmetic/analyzer.cc | 5 +- src/arithmetic/bound_deducer.cc | 8 +- src/arithmetic/canonical_simplify.cc | 6 +- src/arithmetic/compute_expr.h | 10 +- src/arithmetic/const_fold.h | 58 +- src/arithmetic/detect_linear_equation.cc | 8 +- src/arithmetic/int_op_overflow.h | 4 +- src/arithmetic/int_set.cc | 1020 +++++++++-------- src/arithmetic/int_set.h | 143 +++ src/arithmetic/int_set_internal.h | 79 -- src/lang/expr_operator.cc | 19 +- src/pass/loop_partition.cc | 29 +- .../unittest/test_arith_deduce_bound.py | 168 +++ tests/python/unittest/test_arith_intset.py | 227 ++-- 17 files changed, 1184 insertions(+), 842 deletions(-) create mode 100644 src/arithmetic/int_set.h delete mode 100644 src/arithmetic/int_set_internal.h create mode 100644 tests/python/unittest/test_arith_deduce_bound.py diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 600e3c565358..c506268cb14b 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -328,71 +328,14 @@ class ConstraintContext { std::function exit_; }; -/*! - * \brief Analyzer that contains bunch of sub-analyzers. - * - * Each sub-analyzer can make use of another sub-analyzer - * by weak reference of this. - * - * NOTE for sub-analyzer developers: - * If the analyzer uses memoization, we need to clear the internal - * cache when information about a Var has been overrideen. - */ -class Analyzer { - public: - /*! \brief sub-analyzer: const integer bound */ - ConstIntBoundAnalyzer const_int_bound; - /*! \brief sub-analyzer: modular set */ - ModularSetAnalyzer modular_set; - /*! \brief sub-analyzer rewrite simplify */ - RewriteSimplifier rewrite_simplify; - /*! \brief sub-analyzer canonical simplify */ - CanonicalSimplifier canonical_simplify; - /*! \brief constructor */ - Analyzer(); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to expr. - * - * Each var can only be binded once. - * - * \param var The variable. - * \param expr The expression we bind to. - */ - void Bind(const VarExpr& var, const Expr& expr); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to a range. - * - * Each var can only be binded once. - * - * \param var The variable. - * \param range The range we bind to. - */ - void Bind(const VarExpr& var, const Range& range); - /*! - * \brief Whether can we proof expr >= val. - - * Non-negative proof is very useful in integer analysis - * to lower divisions and mods given difference in trunc and ceil mode. - * - * \param expr The expression. - * \param lower_bound The lower bound. - * \return Whether we can proof it. - * - * \note Analyzer will call into sub-analyzers to get the result. - */ - bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); -}; - //----------------------------------------------- -// Integer set abstraction API. +// Integer set data structure. // // This is a API build on top of the base // integer analysis API to provide set analysis. //------------------------------------------------ /*! - * \brief Sign of an expression or set. + * \brief Sign type of an integer expression. */ enum SignType { kPositive, @@ -401,8 +344,13 @@ enum SignType { kUnknown }; -// internal node container of int set. -struct IntSetNode; +/*! + * \brief Base class of all IntSet containers. + */ +struct IntSetNode : public Node { + static constexpr const char* _type_key = "IntSet"; + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +}; /*! * \brief Integer set class, represent a set of integers in one dimension. @@ -424,11 +372,6 @@ class IntSet : public NodeRef { * \return The covering range. */ Range cover_range(Range max_range) const; - /*! - * \brief find an interval that covers the set. - * \return The covering interval set. - */ - IntSet cover_interval() const; /*! \return Lower bound of the set */ Expr min() const; /*! \return upper bound of the set */ @@ -493,33 +436,91 @@ class IntSet : public NodeRef { }; /*! - * \brief Base class of all IntSet containers. + * \brief Integer set analyzer. */ -struct IntSetNode : public Node { - static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +class IntSetAnalyzer { + public: + /*! + * \brief Find a symbolic integer set that contains all possible values of + * expr given the domain of each variables. + * + * \param expr The expression of interest. + * \param dom_map The domain map to indicate which variable to relax. + * \return the result of the analysis. + */ + IntSet operator()(const Expr& expr, const Map& dom_map); + + private: + friend class Analyzer; + explicit IntSetAnalyzer(Analyzer* parent); + ~IntSetAnalyzer(); + class Impl; + /*! \brief Internal impl */ + Impl* impl_; }; /*! - * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] - * Where coeff[i] and base are invariant of var[j] for all i and j. + * \brief Analyzer that contains bunch of sub-analyzers. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return [coeff[i]] if it is possible, empty array if it is not. - */ -Array DetectLinearEquation(const Expr& e, const Array& vars); - -/*! - * \brief Detect if expression corresponds to clip bound of the vars + * Each sub-analyzer can make use of another sub-analyzer + * by weak reference of this. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value - * return empty if the e does not match the pattern. + * NOTE for sub-analyzer developers: + * If the analyzer uses memoization, we need to clear the internal + * cache when information about a Var has been overridden. */ -Array DetectClipBound(const Expr& e, const Array& vars); +class Analyzer { + public: + /*! \brief sub-analyzer: const integer bound */ + ConstIntBoundAnalyzer const_int_bound; + /*! \brief sub-analyzer: modular set */ + ModularSetAnalyzer modular_set; + /*! \brief sub-analyzer rewrite simplify */ + RewriteSimplifier rewrite_simplify; + /*! \brief sub-analyzer canonical simplify */ + CanonicalSimplifier canonical_simplify; + /*! \brief sub-analyzer: int set */ + IntSetAnalyzer int_set; + /*! \brief constructor */ + Analyzer(); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to expr. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param expr The expression we bind to. + */ + void Bind(const VarExpr& var, const Expr& expr); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to a range. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const VarExpr& var, const Range& range); + /*! + * \brief Whether can we prove expr >= val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param lower_bound The lower bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); +}; +//----------------------------------------------- +// Integer set legacy API. +//------------------------------------------------ /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond, */ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); +// Expression pattern detector. +/*! + * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] + * Where coeff[i] and base are invariant of var[j] for all i and j. + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return [coeff[i]] if it is possible, empty array if it is not. + */ +Array DetectLinearEquation(const Expr& e, + const Array& vars); + +/*! + * \brief Detect if expression corresponds to clip bound of the vars + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value + * return empty if the e does not match the pattern. + */ +Array DetectClipBound(const Expr& e, + const Array& vars); + // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/python/tvm/arith.py b/python/tvm/arith.py index eda5cb825326..4c3c05f75796 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -32,21 +32,21 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_node +@register_node("arith.IntervalSet") class IntervalSet(IntSet): - """Represent set of continuous interval""" - def min(self): - """get the minimum value""" - return _api_internal._IntervalSetGetMin(self) - - def max(self): - """get the maximum value""" - return _api_internal._IntervalSetGetMax(self) + """Represent set of continuous interval [min_value, max_value] + Parameters + ---------- + min_value : Expr + The minimum value in the interval. -@register_node -class StrideSet(IntSet): - """Represent set of strided integers""" + max_value : Expr + The maximum value in the interval. + """ + def __init__(self, min_value, max_value): + self.__init_handle_by_constructor__( + _make_IntervalSet, min_value, max_value) @register_node("arith.ModularSet") @@ -114,6 +114,7 @@ def __init__(self): self._modular_set = _mod("modular_set") self._rewrite_simplify = _mod("rewrite_simplify") self._canonical_simplify = _mod("canonical_simplify") + self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") def const_int_bound(self, expr): @@ -176,6 +177,24 @@ def canonical_simplify(self, expr): """ return self._canonical_simplify(expr) + def int_set(self, expr, dom_map): + """Compute a symbolic IntSet that covers expr for all values in dom_map. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + dom_map : Dict[Var, tvm.arith.IntSet] + The domain for variables to be relaxed. + + Returns + ------- + result : IntSet + The result. + """ + return self._int_set(expr, dom_map) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 4d5d8bdf58d3..f31f02b1eaf4 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector") TVM_REGISTER_API("arith.intset_interval") .set_body_typed(IntSet::interval); + TVM_REGISTER_API("arith.DetectLinearEquation") .set_body_typed(DetectLinearEquation); @@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->int_set(args[0], args[1]); + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { auto& sptr = args[1].node_sptr(); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index bd8c7005f458..10a1c7f041c3 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -31,7 +31,8 @@ Analyzer::Analyzer() : const_int_bound(this), modular_set(this), rewrite_simplify(this), - canonical_simplify(this) { + canonical_simplify(this), + int_set(this) { } void Analyzer::Bind(const VarExpr& v, const Expr& expr) { @@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() { bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { - return ptr->value > lower_bound; + return ptr->value >= lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); if (bd->min_value >= lower_bound) return true; diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 89e556c6f75f..395a371f43af 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,12 +30,12 @@ #include #include +#include "int_set.h" namespace tvm { namespace arith { using namespace ir; -using HalideIR::Internal::Interval; // a visitor to find the path to the target variable // from a expression. @@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e, BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success) return IntSet::nothing(); - Expr min = Interval::neg_inf, max = Interval::pos_inf; + Expr min = neg_inf(), max = pos_inf(); if (d.is_greater) { min = d.result; } else { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1bf1f84fb635..a50cbfb96591 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file canonical_simplify.cc * \brief Canonical form based simplification. */ @@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) { if (TryCompare(temp, cval) == kLT) { return temp; } else { - return SplitModConst(ToSplitExpr(temp), cval); + // contonue to use logic below. + a = extra; + psum = a.as(); + CHECK(psum != nullptr); } } } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index ff2fb8dbd4ac..cc54bff596be 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,8 +27,8 @@ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #include -#include #include +#include namespace tvm { namespace arith { @@ -105,12 +105,12 @@ inline Expr ComputeExpr(Expr a, Expr b) { template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_max(a, b); + return max(a, b); } template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_min(a, b); + return min(a, b); } template diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index fbf8fe7e6f89..ec50aef5c51e 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -206,6 +206,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -216,6 +217,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -307,6 +309,58 @@ inline Expr TryConstFold(Expr a) { return Expr(); } +/*! \brief Helper namespace for symbolic value limits */ +struct SymbolicLimits { + /*! \brief positive infinity */ + static Expr pos_inf_; + /*! \brief negative infinity */ + static Expr neg_inf_; +}; + +/*! + * \brief Opaque expression representing positive infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return positive infinity. + */ +inline Expr pos_inf() { + return SymbolicLimits::pos_inf_; +} + +/*! + * \brief Check if value is positive infinity. + * \param value The value to be checked. + * + * \return The check result. + */ +inline bool is_pos_inf(const Expr& value) { + return value.same_as(SymbolicLimits::pos_inf_); +} + +/*! + * \brief Opaque expression representing negative infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return negative infinity. + */ +inline Expr neg_inf() { + return SymbolicLimits::neg_inf_; +} + +/*! + * \brief Check if value is negative infinity. + * \param value The value to be checked. + * + * \return The check result. + */ +inline bool is_neg_inf(const Expr& value) { + return value.same_as(SymbolicLimits::neg_inf_); +} + } // namespace arith } // namespace tvm #endif // TVM_ARITHMETIC_CONST_FOLD_H_ diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 2fe21fef7e21..e584c8b1ce33 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,8 +19,8 @@ /*! * Copyright (c) 2017 by Contributors - * \file bound_deducer.cc - * \brief Utility to deduce bound of expression + * \file detect_linear_equation.cc + * \brief Utility to detect patterns in the expression. */ #include #include diff --git a/src/arithmetic/int_op_overflow.h b/src/arithmetic/int_op_overflow.h index 87f4f059e858..b78f21cb1dba 100644 --- a/src/arithmetic/int_op_overflow.h +++ b/src/arithmetic/int_op_overflow.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index abbb7cd9744e..75a4aaf83ab6 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,201 +18,55 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file int_set.cc * \brief The integer set functions */ #include -#include -#include #include -#include +#include +#include +#include #include -#include "compute_expr.h" -#include "int_set_internal.h" +#include "int_set.h" +#include "pattern_match.h" namespace tvm { namespace arith { -using HalideIR::Internal::Interval; -using namespace ir; - -inline IntSet IntSet::cover_interval() const { - if ((*this).as()) return *this; - const StrideSet* s = (*this).as(); - if (s) { - CHECK_NE(s->extents.size(), 0U); - Expr max = s->base.max; - for (size_t i = 0; i < s->extents.size(); ++i) { - max = max + s->extents[i] * s->strides[i] - s->strides[i]; - } - return IntervalSet::make(s->base.min, Simplify(max)); - } - LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval"; - return IntSet::everything(); -} - -Range IntSet::cover_range(Range max_range) const { - IntSet temp; - const IntervalSet* s_int = (*this).as(); - if (s_int == nullptr) { - temp = this->cover_interval(); - s_int = temp.as(); - } - if (s_int->i.is_bounded()) { - return Range::make_by_min_extent( - s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); - } - return max_range; -} - -Expr IntSet::min() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.min; -} - -Expr IntSet::max() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.max; -} - -bool IntSet::is_nothing() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_empty()); -} - -bool IntSet::is_everything() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_everything()); -} +Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle()); +Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle()); -bool IntSet::is_single_point() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_single_point()); +IntervalSet::IntervalSet(Expr min_value, Expr max_value) { + auto node = make_node(); + node->min_value = std::move(min_value); + node->max_value = std::move(max_value); + node_ = std::move(node); } -bool IntSet::can_prove_positive() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_positive_const(ir::Simplify(s_int->i.min))); +IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { + return IntervalSet(min_value, max_value); } -bool IntSet::can_prove_negative() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); -} +TVM_REGISTER_API("arith._make_IntervalSet") +.set_body_typed(MakeIntervalSet); -bool IntSet::can_prove_non_positive() const { - if (const IntervalSet* s_int = (*this).as()) { - auto max = ir::Simplify(s_int->i.max); - return is_zero(max) || is_negative_const(max); - } - return false; -} -bool IntSet::can_prove_non_negative() const { - if (const IntervalSet* s_int = (*this).as()) { - // Any reason why we should or should not use can_prove() to implement - // these functions? - auto min = ir::Simplify(s_int->i.min); - return is_zero(min) || is_positive_const(min); - } - return false; -} - - -SignType IntSet::sign_type() const { - if (can_prove_positive()) { - return kPositive; - } else if (can_prove_negative()) { - return kNegative; - } else if (is_single_point() && is_zero(point_value())) { - return kZero; +IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = min(a->max_value, b->max_value); + Expr min_value = max(a->min_value, b->min_value); + if ((max_value.type().is_int() || max_value.type().is_uint()) && + (min_value.type().is_int() || min_value.type().is_uint()) && + analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { + return IntervalSet::Empty(); } else { - return kUnknown; - } -} -Expr IntSet::point_value() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int && s_int->i.is_single_point()); - return s_int->i.min; -} - -IntSet IntSet::nothing() { - return IntervalSet::make(Interval::nothing()); -} - -IntSet IntSet::everything() { - return IntervalSet::make(Interval::everything()); -} - -IntSet IntSet::single_point(Expr x) { - return IntervalSet::make(Interval::single_point(x)); -} - -IntSet IntSet::range(Range r) { - // must make sure it can be matched back by MatchRange. - if (is_one(r->extent)) { - return IntSet::single_point(r->min); - } - if (is_positive_const(r->extent) && is_const(r->min)) { - return IntervalSet::make( - r->min, ComputeExpr(ComputeExpr(r->extent, r->min), 1)); - } - return IntervalSet::make(r->min, (r->extent + r->min) - 1); -} - -IntSet IntSet::interval(Expr min, Expr max) { - if (min.same_as(max)) { - return IntSet::single_point(min); - } - return IntervalSet::make(min, max); -} - -inline bool prove_equal(Expr lhs, Expr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); -} - -// Check if a is created from b. -bool IntSet::match_range(const Range& b) const { - const IntSet& a = *this; - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return prove_equal(i.min, b->min) && - prove_equal(i.max, ComputeExpr(ComputeExpr(b->extent, b->min), 1)); -} - -inline bool MatchPoint(const IntSet& a, - const Expr& b) { - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return i.is_single_point() && i.min.same_as(b); -} - -IntSet Union(const Array& sets) { - if (sets.size() == 0) return IntSet::nothing(); - if (sets.size() == 1) return sets[0]; - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - IntSet s = sets[i].cover_interval(); - const Interval& y = s.as()->i; - x.include(y); + return IntervalSet(min_value, max_value); } - x.max = ir::Simplify(x.max); - x.min = ir::Simplify(x.min); - return IntervalSet::make(x); } -IntSet Intersect(const Array& sets) { - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - Interval y = sets[i].cover_interval().as()->i; - x = Interval::make_intersection(x, y); - } - return IntervalSet::make(x); +IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = max(a->max_value, b->max_value); + Expr min_value = min(a->min_value, b->min_value); + return IntervalSet(min_value, max_value); } // type traits @@ -227,407 +81,623 @@ struct is_logical_op { static const bool value = true; \ }; -// interval related. -template -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key; - return IntSet::everything(); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); + +/*! + * \brief Combine two interval set under arithmetic operations. + * \note this can possibly relax the set. + */ +template +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + Expr res = TryConstFold(a->min_value, b->min_value); + if (!res.defined()) res = Op::make(a->min_value, b->min_value); + return IntervalSet::SinglePoint(res); + } + if (is_logical_op::value) { + return IntervalSet(make_const(a->min_value.type(), 0), + make_const(a->min_value.type(), 1)); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsEverything()) return a; + if (b->IsEverything()) return b; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_lower_bound()) { - r.min = ComputeExpr(a.min, b.min); - } - if (a.has_upper_bound() && b.has_upper_bound()) { - r.max = ComputeExpr(a.max, b.max); - } - return IntervalSet::make(r); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value + b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + a->HasLowerBound() && b->HasLowerBound() ? + a->min_value + b->min_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasUpperBound() ? + a->max_value + b->max_value : pos_inf(); + return IntervalSet(min_value, max_value); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value - b->min_value); } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_upper_bound()) { - r.min = ComputeExpr(a.min, b.max); - } - if (a.has_upper_bound() && b.has_lower_bound()) { - r.max = ComputeExpr(a.max, b.min); - } - return IntervalSet::make(r); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + a->HasLowerBound() && b->HasUpperBound() ? + a->min_value - b->max_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasLowerBound() ? + a->max_value - b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); } + template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - if (a.is_single_point() && !b.is_single_point()) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value * b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsSinglePoint()) { std::swap(a, b); } - if (b.is_single_point()) { - if (is_zero(b.min)) return IntSet::single_point(0); - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? ComputeExpr(a.max, b.min) : a.max; - // no relaxation is needed in here due to set is inclusive - // TODO(tqchen): consider convert to StrideSet. - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) return b; + if (is_one(b->min_value)) return a; + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value * b->min_value; + Expr e2 = a->max_value * b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Mul"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mul"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval
(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr
(a.min, b.min)); - } - if (b.is_single_point()) { - if (is_zero(b.min)) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value / b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) { LOG(FATAL) << "Divide by zero in CombineInterval Div"; } - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr
(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? ComputeExpr
(a.max, b.min) : a.max; + if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value / b->min_value; + Expr e2 = a->max_value / b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Div"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Div"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value % b->min_value); } - if (b.is_single_point()) { - Expr divisor = b.min; + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + + if (b->IsSinglePoint()) { + const Expr& divisor = b->min_value; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } - return IntervalSet::make(make_zero(divisor.type()), divisor - 1); + // We need to add more bound constraints throughout the code. + // The logic below assumes a is non-negative, which usually + // is the case of our application. + // TODO(tqchen): add bound constraints for a. + if (analyzer->CanProveGreaterEqual(divisor, 0)) { + return IntervalSet(make_zero(divisor.type()), divisor - 1); + } else { + Expr bound = abs(divisor) - 1; + return IntervalSet(-bound, bound); + } } - - LOG(WARNING) << "Return Everything in CombineInterval Mod"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mod"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } - return IntervalSet::make(Interval::make_max(a.min, b.min), - Interval::make_max(a.max, b.max)); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(max(a->min_value, b->min_value), + max(a->max_value, b->max_value)); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } - return IntervalSet::make(Interval::make_min(a.min, b.min), - Interval::make_min(a.max, b.max)); -} - -template -inline IntSet CombineInterval_(IntSet a, IntSet b) { - return CombineInterval( - a.as()->i, b.as()->i); -} - -// stride related -inline IntSet AsStrideSet(IntSet a) { - if (a.as()) return a; - const IntervalSet* s = a.as(); - CHECK(s->i.is_bounded()); - NodePtr n = make_node(); - n->base = s->i; - return IntSet(n); -} -template -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineInterval_(a.cover_interval(), b.cover_interval()); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(min(a->min_value, b->min_value), + min(a->max_value, b->max_value)); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && is_zero(a_int->i.min)) return b; - if (b_int && is_zero(b_int->i.min)) return a; - a = AsStrideSet(a); - b = AsStrideSet(b); - const StrideSet* a_stride = a.as(); - const StrideSet* b_stride = b.as(); - auto n = make_node(*a_stride); - for (size_t i = 0; i < b_stride->extents.size(); ++i) { - n->extents.push_back(b_stride->extents[i]); - n->strides.push_back(b_stride->strides[i]); - } - n->base = CombineInterval( - a_stride->base, b_stride->base).as()->i; - return IntSet(n); -} - -inline IntSet NegateSet(IntSet a) { - const IntervalSet* a_int = a.as(); - if (a_int) { - if (a_int->i.is_single_point()) { - return IntSet::single_point(-a_int->i.min); - } else { - Interval r = Interval::everything(); - if (a_int->i.has_upper_bound()) { - r.min = -(a_int->i.max); - } - if (a_int->i.has_lower_bound()) { - r.max = -(a_int->i.min); - } - return IntervalSet::make(r); - } - } else { - return NegateSet(a.cover_interval()); +// internal helper function to get an interval set +IntervalSet ToIntervalSet(IntSet set) { + if (auto* node = set.as()) { + return GetRef(node); } + DLOG(INFO) << "cannot resolve int set " << set; + return IntervalSet::Everything(); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineSets(a, NegateSet(b)); -} - -TVM_DECLARE_LOGICAL_OP(And); -TVM_DECLARE_LOGICAL_OP(Or); -TVM_DECLARE_LOGICAL_OP(EQ); -TVM_DECLARE_LOGICAL_OP(NE); -TVM_DECLARE_LOGICAL_OP(GE); -TVM_DECLARE_LOGICAL_OP(GT); -TVM_DECLARE_LOGICAL_OP(LE); -TVM_DECLARE_LOGICAL_OP(LT); -TVM_DECLARE_LOGICAL_OP(Not); +using namespace ir; -// generic combine operations of two sets -template -inline IntSet Combine(const IntSet& a, const IntSet &b) { - if (is_logical_op::value) { - return IntervalSet::make(0, 1); +// Simplified version of int set evaluator that operates on IntervalSet +// We might use better set analysis in the future to replace the intervalset. +class IntervalSetEvaluator : + public ExprFunctor { + public: + IntervalSetEvaluator(Analyzer* analyzer, + const Map& dom_map, + bool eval_vec = false) + : analyzer_(analyzer), + dom_map_(dom_map), + eval_vec_(eval_vec) { } - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && a_int->i.is_everything()) return a; - if (b_int && b_int->i.is_everything()) return b; - if (a_int && b_int) { - return CombineInterval(a_int->i, b_int->i); + + IntervalSet Eval(const Expr& val) { + return this->VisitExpr(val); } - if (a_int && !(a_int->i.is_bounded())) { - return CombineInterval_(a, b.cover_interval()); + + IntervalSet VisitExpr_(const IntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - if (b_int && !(b_int->i.is_bounded())) { - return CombineInterval_(a.cover_interval(), b); + + IntervalSet VisitExpr_(const UIntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - return CombineSets(a, b); -} -class IntSetEvaluator : - public ExprFunctor { - public: - explicit IntSetEvaluator( - const std::unordered_map& dom_map, - bool eval_vec = false) - : dom_map_(dom_map), eval_vec_(eval_vec) {} - // Evaluate. - IntSet Eval(const Expr& e) { - return this->VisitExpr(e, e); - } - IntSet VisitExpr_(const IntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const UIntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const Variable* op, const Expr& e) final { - auto it = dom_map_.find(op); + IntervalSet VisitExpr_(const Variable* op) final { + Var var = GetRef(op); + auto it = dom_map_.find(var); if (it != dom_map_.end()) { - return it->second; + return ToIntervalSet((*it).second); } else { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(var); } } - IntSet VisitExpr_(const Add* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Add* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Sub* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Sub* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mul* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mul* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Div* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Div* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mod* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mod* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Min* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Min* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Max* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Max* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const EQ* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const EQ* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const NE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const NE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const And* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const And* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Or* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Or* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Ramp* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Ramp* op) final { CHECK(eval_vec_); - IntSet base = Eval(op->base); - int vstride; - if (GetConstInt(op->stride, &vstride)) { + IntervalSet base = Eval(op->base); + PVar stride; + if (stride.Match(op->stride)) { Type t = op->base.type(); - if (vstride > 0) { + int64_t vstride = stride.Eval()->value; + if (vstride> 0) { return Combine( + analyzer_, base, - IntSet::interval(make_zero(t), - make_const(t, vstride * op->lanes -1))); + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { return Combine( + analyzer_, base, - IntSet::interval(make_const(t, vstride * op->lanes + 1), - make_zero(t))); + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } - LOG(WARNING) << "cannot evaluate set on expression " << e; - return IntSet::everything(); + DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + return IntervalSet::Everything(); } - IntSet VisitExpr_(const Broadcast* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Broadcast* op) final { CHECK(eval_vec_); - return Eval(op->value); + return VisitExpr(op->value); } - IntSet VisitExpr_(const Select* op, const Expr& e) final { - IntSet true_set = this->Eval(op->true_value); - IntSet false_set = this->Eval(op->false_value); - return Union({false_set, true_set}); + + IntervalSet VisitExpr_(const Select* op) final { + IntervalSet true_set = this->Eval(op->true_value); + IntervalSet false_set = this->Eval(op->false_value); + return Union(analyzer_, false_set, true_set); } - IntSet VisitExprDefault_(const Node* op, const Expr& e) final { - LOG(WARNING) << "cannot evaluate set type " << e->type_key(); - return IntSet::everything(); + + IntervalSet VisitExprDefault_(const Node* op) final { + DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + return IntervalSet::Everything(); } private: + // whether set is exactly single point that equals value. + bool MatchPoint(const IntervalSet& set, + const Expr& value) const { + return set->min_value.same_as(value) && set->max_value.same_as(value); + } + template - inline IntSet Binary(const T* op, const Expr& e) { - IntSet a = this->Eval(op->a); - IntSet b = this->Eval(op->b); + inline IntervalSet VisitBinaryExpr_(const T* op) { + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(a, b); + return Combine(analyzer_, a, b); } - const std::unordered_map& dom_map_; + Analyzer* analyzer_; + const Map& dom_map_; bool eval_vec_{false}; }; +class IntSetAnalyzer::Impl { + public: + explicit Impl(Analyzer* analyzer) + : analyzer_(analyzer) { + } + + IntSet Eval(const Expr& expr, const Map& dom_map) const { + return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); + } + + private: + Analyzer* analyzer_; +}; + +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +IntSetAnalyzer::~IntSetAnalyzer() { + delete impl_; +} + +IntSet IntSetAnalyzer::operator()(const Expr& expr, + const Map& dom_map) { + return impl_->Eval(expr, dom_map); +} + +// Quickly adapt to IntSet interface +// TODO(tqchen): revisit IntSet interface as well. +Range IntSet::cover_range(Range max_range) const { + IntSet temp; + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int != nullptr); + if (s_int->HasUpperBound() && s_int->HasLowerBound()) { + return Range::make_by_min_extent( + s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + } + return max_range; +} + +Expr IntSet::min() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->min_value; +} + +Expr IntSet::max() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->max_value; +} + +bool IntSet::is_nothing() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEmpty()); +} + +bool IntSet::is_everything() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEverything()); +} + +bool IntSet::is_single_point() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsSinglePoint()); +} + +bool IntSet::can_prove_positive() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_positive_const(ir::Simplify(s_int->min_value))); +} + +bool IntSet::can_prove_negative() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_negative_const(ir::Simplify(s_int->max_value))); +} + +bool IntSet::can_prove_non_positive() const { + if (const auto* s_int = (*this).as()) { + auto max = ir::Simplify(s_int->max_value); + return is_zero(max) || is_negative_const(max); + } + return false; +} + +bool IntSet::can_prove_non_negative() const { + if (const IntervalSetNode* s_int = (*this).as()) { + auto min = ir::Simplify(s_int->min_value); + return is_zero(min) || is_positive_const(min); + } + return false; +} + +SignType IntSet::sign_type() const { + if (can_prove_positive()) { + return kPositive; + } else if (can_prove_negative()) { + return kNegative; + } else if (is_single_point() && is_zero(point_value())) { + return kZero; + } else { + return kUnknown; + } +} +Expr IntSet::point_value() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int && s_int->IsSinglePoint()); + return s_int->min_value; +} + +IntSet IntSet::nothing() { + return IntervalSet::Empty(); +} + +IntSet IntSet::everything() { + return IntervalSet::Everything(); +} + +IntSet IntSet::single_point(Expr x) { + return IntervalSet::SinglePoint(x); +} + +IntSet IntSet::interval(Expr min, Expr max) { + if (min.same_as(max)) { + return IntSet::single_point(min); + } + return IntervalSet(min, max); +} + +// Range related code +inline bool ProveEqual(Expr lhs, Expr rhs) { + return is_zero(ir::Simplify(lhs - rhs)); +} + +IntSet IntSet::range(Range r) { + // must make sure it can be matched back by MatchRange. + if (is_one(r->extent)) { + return IntSet::single_point(r->min); + } + return IntervalSet(r->min, r->extent + r->min - 1); +} + +bool IntSet::match_range(const Range& b) const { + const IntSet& a = *this; + const IntervalSetNode* a_int = a.as(); + if (!a_int) return false; + return ProveEqual(a_int->min_value, b->min) && + ProveEqual(a_int->max_value, b->extent + b->min - 1); +} + +IntSet Union(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Union(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +IntSet Intersect(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Intersect(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +Map ConvertDomMap(const Map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(kv.first->var, kv.second); + } + return dmap; +} + +Map ConvertDomMap( + const std::unordered_map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(GetRef(kv.first), kv.second); + } + return dmap; +} + IntSet EvalSet(Expr e, - const std::unordered_map& dom_map) { - return IntSetEvaluator(dom_map, false).Eval(e); + const Map& dom_map) { + Analyzer ana; + return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } IntSet IntSet::vector(Expr x) { - std::unordered_map dmap; - return IntSetEvaluator(dmap, true).Eval(x); + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } IntSet EvalSet(Expr e, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(e, dmap); + return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, +IntSet EvalSet(Expr e, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - IntSet min_set = m.Eval(r->min).cover_interval(); + return EvalSet(e, ConvertDomMap(dom_map)); +} + +IntSet EvalSet(Range r, + const Map& dom_map) { + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + IntervalSet min_set = m.Eval(r->min); // Simplifying first can give tighter bounds if r->min and r->extent share variables - Expr sum = ComputeExpr(ComputeExpr(r->min, r->extent), 1); - IntSet max_set = m.Eval(Simplify(sum)).cover_interval(); - const Interval& ni = min_set.as()->i; - const Interval& xi = max_set.as()->i; - if (!ni.has_lower_bound()) return IntSet::everything(); - if (!xi.has_upper_bound()) return IntSet::everything(); - return IntervalSet::make(ni.min, xi.max); + Expr sum = r->min + r->extent - 1; + IntervalSet max_set = m.Eval(Simplify(sum)); + if (!min_set->HasLowerBound()) return IntSet::everything(); + if (!max_set->HasUpperBound()) return IntSet::everything(); + return IntervalSet(min_set->min_value, max_set->max_value); } -IntSet EvalSet(IntSet s, +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - s = s.cover_interval(); - const IntervalSet* s_int = s.as(); - Expr vmax = s_int->i.has_upper_bound() ? - m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max; - Expr vmin = s_int->i.has_lower_bound() ? - m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min; - return IntervalSet::make(vmin, vmax); + return EvalSet(r, ConvertDomMap(dom_map)); } -class SubExprIntSetEvaluator : public IntSetEvaluator { +IntSet EvalSet(IntSet s, + const std::unordered_map& dom_map) { + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + IntervalSetEvaluator m(&ana, dmap); + const IntervalSetNode* s_int = s.as(); + Expr vmax = s_int->HasUpperBound() ? + m.Eval(s_int->max_value).max() : s_int->max_value; + Expr vmin = s_int->HasLowerBound() ? + m.Eval(s_int->min_value).min() : s_int->min_value; + return IntervalSet(vmin, vmax); +} + +class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntSetEvaluator( - const std::unordered_map& dom_map) - : IntSetEvaluator(dom_map) {} + explicit SubExprIntervalSetEvaluator( + Analyzer* analyzer, + const Map& dom_map) + : IntervalSetEvaluator(analyzer, dom_map) {} - IntSet VisitExpr(const Expr& n, const Expr& e) final { - IntSet ret = IntSetEvaluator::VisitExpr(n, e); + IntervalSet VisitExpr(const Expr& n) final { + IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); expr_map[n] = ret; return ret; } @@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr(Expr e, +ExprIntSetMap EvalSetForEachSubExpr( + Expr e, const std::unordered_map& dom_map) { - SubExprIntSetEvaluator m(dom_map); + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + SubExprIntervalSetEvaluator m(&ana, dmap); m.Eval(e); return m.expr_map; } IntSet EvalSet(Range r, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(r, dmap); + return EvalSet(r, ConvertDomMap(dom_map)); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const IntervalSet *op, IRPrinter *p) { - p->stream << "interval-set" - << "[" << op->i.min << ", " - << op->i.max << ']'; +.set_dispatch([](const IntervalSetNode *op, IRPrinter *p) { + p->stream << "IntervalSet" + << "[" << op->min_value << ", " + << op->max_value << ']'; }); - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h new file mode 100644 index 000000000000..bf7fec24f78a --- /dev/null +++ b/src/arithmetic/int_set.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file int_set.h + * \brief Internal data structure for integer set. + */ +#ifndef TVM_ARITHMETIC_INT_SET_H_ +#define TVM_ARITHMETIC_INT_SET_H_ + +#include +#include +#include +#include "const_fold.h" + +namespace tvm { +namespace arith { + +/*! + * \brief Symbolic interval set. + * + * \note We intentionally keep the internal of IntSet private, + as we might change it later. + */ +class IntervalSetNode : public IntSetNode { + public: + /*! \brief Minimum value in the interval. */ + Expr min_value; + /*! \brief Maximum value in the interval. */ + Expr max_value; + + // visitor overload. + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("min_value", &min_value); + v->Visit("max_value", &max_value); + } + + /*! \return Whether the interval has upper bound. */ + bool HasUpperBound() const { + return !is_pos_inf(max_value) && !IsEmpty(); + } + /*! \return Whether the interval has lower bound. */ + bool HasLowerBound() const { + return !is_neg_inf(min_value) && !IsEmpty(); + } + /*! \return Whether the interval is a single point. */ + bool IsSinglePoint() const { + return min_value.same_as(max_value); + } + /*! \return whether interval represent nothing */ + bool IsEmpty() const { + // during computations, either extreme could occur. + return is_pos_inf(min_value) || is_neg_inf(max_value); + } + /*! \return whether interval represent everything */ + bool IsEverything() const { + return is_neg_inf(min_value) && is_pos_inf(max_value); + } + + static constexpr const char* _type_key = "arith.IntervalSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); +}; + +/*! + * \brief Interval set used for symbolic integer analysis. + * \sa IntervalSetNode + */ +class IntervalSet : public IntSet { + public: + /*! + * \brief Make a new instance of interval set. + * \param min_value The minimum value in the interval. + * \param max_value The maximum value in the interval. + * \return The created set. + */ + TVM_DLL IntervalSet(Expr min_value, Expr max_value); + + /*! + * \brief Create an IntervalSet that represents a single point. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet SinglePoint(Expr value) { + return IntervalSet(value, value); + } + /*! + * \brief Create an IntervalSet that represents everything. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet Everything() { + return IntervalSet(neg_inf(), pos_inf()); + } + /*! + * \brief Create an empty eet. + * \return The result set. + */ + static IntervalSet Empty() { + return IntervalSet(pos_inf(), neg_inf()); + } + + TVM_DEFINE_NODE_REF_COW(IntervalSetNode); + TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); +}; + +/*! + * \brief Create union of two IntervalSets. + * \param analyzer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); + +/*! + * \brief Create insersection of two IntervalSets. + * \param analzyer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITHMETIC_INT_SET_H_ diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h deleted file mode 100644 index 8b675cfbffda..000000000000 --- a/src/arithmetic/int_set_internal.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2017 by Contributors - * \file int_set_internal.h - * \brief Implementations of integer set - */ -#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_ -#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -using HalideIR::Internal::Interval; - -/*! \brief Set of continuous interval */ -struct IntervalSet : public IntSetNode { - /*! \brief the internal interval*/ - Interval i; - - static IntSet make(Interval i) { - NodePtr n = - make_node(); - n->i = i; - return IntSet(n); - } - static IntSet make(Expr min, Expr max) { - NodePtr n = - make_node(); - n->i.min = min; - n->i.max = max; - return IntSet(n); - } - - static constexpr const char* _type_key = "IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode); -}; - -/*! - * \brief set represented by strided integers - * Reserved for cases where strided access is supported. - */ -struct StrideSet : public IntSetNode { - /*! \brief the base inetrval */ - Interval base; - /*! \brief additional extents in positive number */ - Array extents; - /*! \brief additional strides in positive number */ - Array strides; - - static constexpr const char* _type_key = "StrideSet"; - TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); -}; - -} // namespace arith -} // namespace tvm - -#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_ diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 8537f17b763c..3f5254069b8d 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) { return ir::Mod::make(a, b); } + Expr min(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return b; + if (is_neg_inf(a)) return a; + if (is_pos_inf(b)) return a; + if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) { } Expr max(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return a; + if (is_neg_inf(a)) return b; + if (is_pos_inf(b)) return b; + if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index bcb2608682ee..0a5b7410f3cf 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #include #include #include -#include "../arithmetic/int_set_internal.h" +#include "../arithmetic/int_set.h" #include "../runtime/thread_storage_scope.h" namespace tvm { @@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator { std::pair> GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value); inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); @@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator { /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; std::unordered_map relax_map_; + arith::Analyzer analyzer_; CandidateSelector selector; }; @@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator { // given in the second component provably have value given by cond_value std::pair> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { - arith::Interval interval = kv.second.as()->i; - arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval); - if (!intersection.is_empty()) { + arith::IntervalSet interval = Downcast(kv.second); + arith::IntervalSet intersection = arith::Intersect( + &analyzer_, interval, for_interval); + if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); } @@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr max, Stmt body, bool partition_thread_scope) { + using namespace arith; PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); if (finder.partitions.empty()) return Stmt(); - arith::Interval for_interval(min, max); + arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; std::unordered_set cond_set; @@ -478,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // if such interval doesn't exist, find an interval in which all // conditions on var are false std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, false); + GetIntervalAndCondset(finder.partitions, for_interval, false); if (middle_interval.is_nothing()) // we couldn't find an interval in which the condintions are provably true or false // Therefore, we can't partition the loop based on those conds @@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, cond_value = true; } - arith::Interval middle_interval_i = middle_interval.as()->i; + IntervalSet middle_interval_i = Downcast(middle_interval); // middle_interval is the subrange of the loop variable range for which a // set of conditions are true (or false resp.) // The part of the loop variable range that is before (after resp.) that @@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr body_begin; Stmt pre_stmt; bool pre_stmt_recurse = true; - if (middle_interval_i.has_lower_bound()) { + if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); if (!can_prove(body_begin == min)) { Expr cond = (body_begin - min >= 0); @@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; - if (middle_interval_i.has_upper_bound()) { + if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); if (!can_prove(middle_interval.max() == max)) { // require the extent to be non-negative diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py new file mode 100644 index 000000000000..7fe6f56edea7 --- /dev/null +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + +def test_deduce(): + a = tvm.var('a') + b = tvm.var('b') + c = tvm.var('c') + d = tvm.var('d') + + b_s = tvm.arith.IntervalSet(2, 3) + c_s = tvm.arith.IntervalSet(10, 15) + d_s = tvm.arith.IntervalSet(-3, -1) + zero = tvm.const(0, "int32") + + e0 = (-b)*a+c-d + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + ans0 = ((d - c) /(b*-1)) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + e0 = d*a+c-d + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + ans0 = ((0-c)/d + 1) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + e1 = (a*4+b < c) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + ans1 = (((c - b) + -1)/4) + assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + + # expression containing variable a is on rhs + e1 = (c > a*4+b) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + + e2 = (tvm.max(5, a * 4) < 0) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" + + # expression containing variable a is on rhs + e2 = (zero < tvm.max(5, a * 4)) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" + + + e3 = (-b)+a*c-d + res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + ans3 = 2/c+1 + assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + + res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + +def test_check(): + a = tvm.var('a') + b = tvm.var('b') + c = tvm.var('c') + d = tvm.var('d') + + b_s = tvm.arith.IntervalSet(2, 3) + c_s = tvm.arith.IntervalSet(5, 7) + d_s = tvm.arith.IntervalSet(-3, -1) + + # no compare operator + res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) + assert res1.is_nothing() + + # multiple compare operators + res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) + assert res2.is_nothing() + + # multiple target variable + res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) + assert res2.is_nothing() + +def test_deduce_basic(): + def test_basic(a1, a2, coff): + a = tvm.var('a') + b = tvm.var('b') + b_s = tvm.arith.IntervalSet(a1, a2) + e0 = b + a*coff + 3 + + res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 + + res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 + + test_basic(0, 4, 4) + test_basic(1, 5, 4) + test_basic(2, 6, 4) + test_basic(0, 4, -4) + test_basic(1, 5, -4) + test_basic(2, 6, -4) + +def test_deduce_complex(): + def test_complex(a1, a2, coff): + a = tvm.var('a') + b = tvm.var('b') + b_s = tvm.arith.IntervalSet(a1, a2) + e0 = (b*3 + a* coff) * 4 + + res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 + + res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 + + test_complex(0, 4, 4) + test_complex(0, 4, -4) + test_complex(2, 6, 4) + test_complex(0, 4, -4) + test_complex(1, 5, -4) + test_complex(2, 6, -4) + + +if __name__ == "__main__": + test_check() + test_deduce_basic() + test_deduce_complex() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index a74162ec07f2..fa14bcf48fdf 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,168 +16,87 @@ # under the License. import tvm + +class IntSetChecker: + def __init__(self): + self.analyzer = tvm.arith.Analyzer() + + def verify(self, data, dmap, expected): + res = self.analyzer.int_set(data, dmap) + def err_msg(): + return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) + def equal(x, y): + res = self.analyzer.canonical_simplify(x - y) + return tvm.ir_pass.Equal(res, 0) + assert equal(res.min_value, expected[0]), err_msg() + assert equal(res.max_value, expected[1]), err_msg() + def test_basic(): - s = tvm.arith.intset_interval(2, 3) - assert s.min().value == 2 - assert s.max().value == 3 + s = tvm.arith.IntervalSet(2, 3) + assert s.min_value.value == 2 + assert s.max_value.value == 3 + def test_vector(): base = 10 stride = 3 lanes = 2 s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes)) - assert s.min().value == base - assert s.max().value == base + stride * lanes - 1 - -def test_deduce(): - a = tvm.var('a') - b = tvm.var('b') - c = tvm.var('c') - d = tvm.var('d') - - b_s = tvm.arith.intset_interval(2, 3) - c_s = tvm.arith.intset_interval(10, 15) - d_s = tvm.arith.intset_interval(-3, -1) - zero = tvm.const(0, "int32") - - e0 = (-b)*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((d - c) /(b*-1)) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - e0 = d*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((0-c)/d + 1) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - e1 = (a*4+b < c) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (((c - b) + -1)/4) - assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) - - # expression containing variable a is on rhs - e1 = (c > a*4+b) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) - - e2 = (tvm.max(5, a * 4) < 0) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max()) == "neg_inf" - assert str(res2.min()) == "pos_inf" - - # expression containing variable a is on rhs - e2 = (zero < tvm.max(5, a * 4)) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max()) == "neg_inf" - assert str(res2.min()) == "pos_inf" - - - e3 = (-b)+a*c-d - res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - ans3 = 2/c+1 - assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) - - res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) - -def test_check(): - a = tvm.var('a') - b = tvm.var('b') - c = tvm.var('c') - d = tvm.var('d') - - b_s = tvm.arith.intset_interval(2, 3) - c_s = tvm.arith.intset_interval(5, 7) - d_s = tvm.arith.intset_interval(-3, -1) - - # no compare operator - res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) - assert res1.is_nothing() - - # multiple compare operators - res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) - assert res2.is_nothing() - - # multiple target variable - res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) - assert res2.is_nothing() - -def test_deduce_basic(): - def test_basic(a1, a2, coff): - a = tvm.var('a') - b = tvm.var('b') - b_s = tvm.arith.intset_interval(a1, a2) - e0 = b + a*coff + 3 - - res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 - - res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 - - test_basic(0, 4, 4) - test_basic(1, 5, 4) - test_basic(2, 6, 4) - test_basic(0, 4, -4) - test_basic(1, 5, -4) - test_basic(2, 6, -4) - -def test_deduce_complex(): - def test_complex(a1, a2, coff): - a = tvm.var('a') - b = tvm.var('b') - b_s = tvm.arith.intset_interval(a1, a2) - e0 = (b*3 + a* coff) * 4 - - res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 - - res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 - - test_complex(0, 4, 4) - test_complex(0, 4, -4) - test_complex(2, 6, 4) - test_complex(0, 4, -4) - test_complex(1, 5, -4) - test_complex(2, 6, -4) + assert s.min_value.value == base + assert s.max_value.value == base + stride * lanes - 1 + + +def test_add_sub(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y)) + ck.verify(x + y, + {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)}, + (1, 21)) + ck.verify(x - y, + {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)}, + (-11, 9)) + +def test_mul_div(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) + ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20)) + ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2)) + ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y)) + ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5)) + + +def test_mod(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) + ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9)) + +def test_max_min(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11)) + ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9)) + ck.verify(tvm.min(x, y), {}, (tvm.min(x, y), tvm.min(x, y))) + ck.verify(tvm.max(x, y), {}, (tvm.max(x, y), tvm.max(x, y))) + + +def test_select(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1), + {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11)) + if __name__ == "__main__": test_basic() test_vector() - test_deduce() - test_check() - test_deduce_basic() - test_deduce_complex() + test_add_sub() + test_mul_div() + test_max_min() + test_select() + test_mod() + From ebff9ac67d9a9c56d91429d41ffe3cab9729a7b1 Mon Sep 17 00:00:00 2001 From: Hua Date: Thu, 13 Jun 2019 15:01:42 -0700 Subject: [PATCH 133/176] [Relay] tflite frontend, keep underline with comments in same length. (#3363) --- tests/python/frontend/tflite/test_forward.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 549855f0cfb5..15357d47989c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -340,7 +340,7 @@ def _test_add(data): ####################################################################### # Subtract -# --- +# -------- def _test_sub(data): """ One iteration of subtract """ @@ -354,28 +354,28 @@ def _test_mul(data): ####################################################################### # Divide -# --- +# ------ def _test_div(data): """ One iteration of divide """ return _test_elemwise(math_ops.divide, data) ####################################################################### # Power -# --- +# ----- def _test_pow(data): """ One iteration of power """ return _test_elemwise(math_ops.pow, data) ####################################################################### # Maximum -# --- +# ------- def _test_maximum(data): """ One iteration of maximum """ return _test_elemwise(math_ops.maximum, data) ####################################################################### # Minimum -# --- +# ------- def _test_minimum(data): """ One iteration of minimum """ From c2c40b777e324830c2b5629adabd7c895c0376c7 Mon Sep 17 00:00:00 2001 From: Marcelo Duarte Trevisani Date: Thu, 13 Jun 2019 23:03:49 +0100 Subject: [PATCH 134/176] Update CMakeLists.txt to be more flexible (#3354) --- CMakeLists.txt | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f8bbbe24568..d064c959f17b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,12 @@ tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) +# 3rdparty libraries +tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") +tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include") +tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include") +tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") + # Contrib library options tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none) @@ -52,11 +58,12 @@ tvm_option(USE_TENSORRT "Build with TensorRT, must have CUDA and CUDNN enabled" tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) # include directories +include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") -include_directories("3rdparty/dlpack/include") -include_directories("3rdparty/dmlc-core/include") -include_directories("3rdparty/rang/include") -include_directories("3rdparty/compiler-rt") +include_directories(${DLPACK_PATH}) +include_directories(${DMLC_PATH}) +include_directories(${RANG_PATH}) +include_directories(${COMPILER_RT_PATH}) # initial variables set(TVM_LINKER_LIBS "") From faff49668eee42a77e6a14787501e160a6b7bc40 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Thu, 13 Jun 2019 15:08:40 -0700 Subject: [PATCH 135/176] [VTA] add support to event counters (#3347) * add support to event counters in VTA * fix comment * fix event-counter interface parameter * no longer needed * add sim back * add docs to event counters * fix docs * add more details about event counting * make dpi-module docs more accurate --- vta/apps/tsim_example/src/driver.cc | 7 +- .../chisel/src/main/scala/core/Core.scala | 6 + .../src/main/scala/core/EventCounters.scala | 56 +++++++ .../chisel/src/main/scala/shell/VCR.scala | 151 ++++++------------ vta/include/vta/dpi/module.h | 4 +- vta/python/vta/testing/simulator.py | 9 ++ vta/src/tsim/tsim_driver.cc | 42 +++-- vta/tests/python/unittest/test_vta_insn.py | 50 ++++-- 8 files changed, 190 insertions(+), 135 deletions(-) create mode 100644 vta/hardware/chisel/src/main/scala/core/EventCounters.scala diff --git a/vta/apps/tsim_example/src/driver.cc b/vta/apps/tsim_example/src/driver.cc index ad9d6ddf2620..c1dc61f8bee1 100644 --- a/vta/apps/tsim_example/src/driver.cc +++ b/vta/apps/tsim_example/src/driver.cc @@ -54,23 +54,20 @@ class Device { private: void Launch(uint32_t c, uint32_t length, void* inp, void* out) { dpi_->Launch(wait_cycles_); - // set counter to zero - dpi_->WriteReg(0x04, 0); dpi_->WriteReg(0x08, c); dpi_->WriteReg(0x0c, length); dpi_->WriteReg(0x10, get_half_addr(inp, false)); dpi_->WriteReg(0x14, get_half_addr(inp, true)); dpi_->WriteReg(0x18, get_half_addr(out, false)); dpi_->WriteReg(0x1c, get_half_addr(out, true)); - // launch - dpi_->WriteReg(0x00, 0x1); + dpi_->WriteReg(0x00, 0x1); // launch } uint32_t WaitForCompletion() { uint32_t i, val; for (i = 0; i < wait_cycles_; i++) { val = dpi_->ReadReg(0x00); - if (val == 2) break; // finish + if (val == 2) break; // finish } val = dpi_->ReadReg(0x04); return val; diff --git a/vta/hardware/chisel/src/main/scala/core/Core.scala b/vta/hardware/chisel/src/main/scala/core/Core.scala index 2a2d4e02784f..6c29a88548a7 100644 --- a/vta/hardware/chisel/src/main/scala/core/Core.scala +++ b/vta/hardware/chisel/src/main/scala/core/Core.scala @@ -64,6 +64,7 @@ class Core(implicit p: Parameters) extends Module { val load = Module(new Load) val compute = Module(new Compute) val store = Module(new Store) + val ecounters = Module(new EventCounters) // Read(rd) and write(wr) from/to memory (i.e. DRAM) io.vme.rd(0) <> fetch.io.vme_rd @@ -103,6 +104,11 @@ class Core(implicit p: Parameters) extends Module { store.io.out_baddr := io.vcr.ptrs(5) store.io.out <> compute.io.out + // Event counters + ecounters.io.launch := io.vcr.launch + ecounters.io.finish := compute.io.finish + io.vcr.ecnt <> ecounters.io.ecnt + // Finish instruction is executed and asserts the VCR finish flag val finish = RegNext(compute.io.finish) io.vcr.finish := finish diff --git a/vta/hardware/chisel/src/main/scala/core/EventCounters.scala b/vta/hardware/chisel/src/main/scala/core/EventCounters.scala new file mode 100644 index 000000000000..5a5b095aa332 --- /dev/null +++ b/vta/hardware/chisel/src/main/scala/core/EventCounters.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vta.core + +import chisel3._ +import chisel3.util._ +import vta.util.config._ +import vta.shell._ + +/** EventCounters. + * + * This unit contains all the event counting logic. One common event tracked in + * hardware is the number of clock cycles taken to achieve certain task. We + * can count the total number of clock cycles spent in a VTA run by checking + * launch and finish signals. + * + * The event counter value is passed to the VCR module via the ecnt port, so + * they can be accessed by the host. The number of event counters (nECnt) is + * defined in the Shell VCR module as a parameter, see VCRParams. + * + * If one would like to add an event counter, then the value of nECnt must be + * changed in VCRParams together with the corresponding counting logic here. + */ +class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module { + val vp = p(ShellKey).vcrParams + val io = IO(new Bundle{ + val launch = Input(Bool()) + val finish = Input(Bool()) + val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W))) + }) + val cycle_cnt = RegInit(0.U(vp.regBits.W)) + when (io.launch && !io.finish) { + cycle_cnt := cycle_cnt + 1.U + } .otherwise { + cycle_cnt := 0.U + } + io.ecnt(0).valid := io.finish + io.ecnt(0).bits := cycle_cnt +} diff --git a/vta/hardware/chisel/src/main/scala/shell/VCR.scala b/vta/hardware/chisel/src/main/scala/shell/VCR.scala index 463f55bc8bbd..0f13cfe43cd3 100644 --- a/vta/hardware/chisel/src/main/scala/shell/VCR.scala +++ b/vta/hardware/chisel/src/main/scala/shell/VCR.scala @@ -23,8 +23,6 @@ import chisel3._ import chisel3.util._ import vta.util.config._ import vta.util.genericbundle._ -import scala.collection.mutable.ListBuffer -import scala.collection.mutable.LinkedHashMap import vta.interface.axi._ /** VCR parameters. @@ -33,14 +31,11 @@ import vta.interface.axi._ */ case class VCRParams() { - val nValsReg: Int = 1 - val nPtrsReg: Int = 6 - val regBits: Int = 32 - val nCtrlReg: Int = 4 - val ctrlBaseAddr: Int = 0 - - require (nValsReg > 0) - require (nPtrsReg > 0) + val nCtrl = 1 + val nECnt = 1 + val nVals = 1 + val nPtrs = 6 + val regBits = 32 } /** VCRBase. Parametrize base class. */ @@ -57,9 +52,9 @@ class VCRMaster(implicit p: Parameters) extends VCRBase { val mp = p(ShellKey).memParams val launch = Output(Bool()) val finish = Input(Bool()) - val irq = Output(Bool()) - val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) - val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W))) + val ecnt = Vec(vp.nECnt, Flipped(ValidIO(UInt(vp.regBits.W)))) + val vals = Output(Vec(vp.nVals, UInt(vp.regBits.W))) + val ptrs = Output(Vec(vp.nPtrs, UInt(mp.addrBits.W))) } /** VCRClient. @@ -72,9 +67,9 @@ class VCRClient(implicit p: Parameters) extends VCRBase { val mp = p(ShellKey).memParams val launch = Input(Bool()) val finish = Output(Bool()) - val irq = Input(Bool()) - val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) - val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W))) + val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W))) + val vals = Input(Vec(vp.nVals, UInt(vp.regBits.W))) + val ptrs = Input(Vec(vp.nPtrs, UInt(mp.addrBits.W))) } /** VTA Control Registers (VCR). @@ -97,10 +92,23 @@ class VCR(implicit p: Parameters) extends Module { // Write control (AW, W, B) val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address val wdata = io.host.w.bits.data - val wstrb = io.host.w.bits.strb - val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0))) val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3) val wstate = RegInit(sWriteAddress) + + // read control (AR, R) + val sReadAddress :: sReadData :: Nil = Enum(2) + val rstate = RegInit(sReadAddress) + val rdata = RegInit(0.U(vp.regBits.W)) + + // registers + val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + (2*vp.nPtrs) + val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W))) + val addr = Seq.tabulate(nTotal)(_ * 4) + val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } + val eo = vp.nCtrl + val vo = eo + vp.nECnt + val po = vo + vp.nVals + switch (wstate) { is (sWriteAddress) { when (io.host.aw.valid) { @@ -124,11 +132,8 @@ class VCR(implicit p: Parameters) extends Module { io.host.aw.ready := wstate === sWriteAddress io.host.w.ready := wstate === sWriteData io.host.b.valid := wstate === sWriteResponse - io.host.b.bits.resp := "h_0".U + io.host.b.bits.resp := 0.U - // read control (AR, R) - val sReadAddress :: sReadData :: Nil = Enum(2) - val rstate = RegInit(sReadAddress) switch (rstate) { is (sReadAddress) { @@ -145,98 +150,40 @@ class VCR(implicit p: Parameters) extends Module { io.host.ar.ready := rstate === sReadAddress io.host.r.valid := rstate === sReadData + io.host.r.bits.data := rdata + io.host.r.bits.resp := 0.U - val nPtrsReg = vp.nPtrsReg - val nValsReg = vp.nValsReg - val regBits = vp.regBits - val ptrsBits = mp.addrBits - val nCtrlReg = vp.nCtrlReg - val rStride = regBits/8 - val pStride = ptrsBits/8 - val ctrlBaseAddr = vp.ctrlBaseAddr - val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride - val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride - - val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr) - val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr) - - val ptrsAddr = new ListBuffer[Int]() - for (i <- 0 until nPtrsReg) { - ptrsAddr += i*pStride + ptrsBaseAddr - if (ptrsBits == 64) { - ptrsAddr += i*pStride + rStride + ptrsBaseAddr - } - } - - // AP register - val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B))) - - // ap start - when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) { - c0(0) := true.B - } .elsewhen (io.vcr.finish) { - c0(0) := false.B - } - - // ap done = finish when (io.vcr.finish) { - c0(1) := true.B - } .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) { - c0(1) := false.B + reg(0) := "b_10".U + } .elsewhen (io.host.w.fire() && addr(0).U === waddr) { + reg(0) := wdata } - val c1 = 0.U - val c2 = 0.U - val c3 = 0.U - - val ctrlRegList = List(c0, c1, c2, c3) - - io.vcr.launch := c0(0) - - // interrupts not supported atm - io.vcr.irq := false.B - - // Write pointer and value registers - val pvAddr = valsAddr ++ ptrsAddr - val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg - val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W)))) - val pvRegList = new ListBuffer[UInt]() - - for (i <- 0 until pvNumReg) { - when (io.host.w.fire() && (waddr === pvAddr(i).U)) { - pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask) + for (i <- 0 until vp.nECnt) { + when (io.vcr.ecnt(i).valid) { + reg(eo + i) := io.vcr.ecnt(i).bits + } .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) { + reg(eo + i) := wdata } - pvRegList += pvReg(i) - } - - for (i <- 0 until nValsReg) { - io.vcr.vals(i) := pvReg(i) } - for (i <- 0 until nPtrsReg) { - if (ptrsBits == 64) { - io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2)) - } else { - io.vcr.ptrs(i) := pvReg(nValsReg + i) + for (i <- 0 until (vp.nVals + (2*vp.nPtrs))) { + when (io.host.w.fire() && addr(vo + i).U === waddr) { + reg(vo + i) := wdata } } - // Read pointer and value registers - val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr - val mapRegList = ctrlRegList ++ pvRegList - - val rdata = RegInit(0.U(regBits.W)) - val rmap = LinkedHashMap[Int,UInt]() - - val totalReg = mapRegList.length - for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt } + when (io.host.ar.fire()) { + rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map) + } - val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) } + io.vcr.launch := reg(0)(0) - when (io.host.ar.fire()) { - rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v) + for (i <- 0 until vp.nVals) { + io.vcr.vals(i) := reg(vo + i) } - io.host.r.bits.resp := 0.U - io.host.r.bits.data := rdata + for (i <- 0 until vp.nPtrs) { + io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) + } } diff --git a/vta/include/vta/dpi/module.h b/vta/include/vta/dpi/module.h index d2e4c80129eb..c83dad1b3299 100644 --- a/vta/include/vta/dpi/module.h +++ b/vta/include/vta/dpi/module.h @@ -35,7 +35,7 @@ namespace dpi { class DPIModuleNode : public tvm::runtime::ModuleNode { public: /*! - * \brief Launch accelerator until it finishes or reach max_cycles + * \brief Launch hardware simulation until accelerator finishes or reach max_cycles * \param max_cycles The maximum of cycles to wait */ virtual void Launch(uint64_t max_cycles) = 0; @@ -53,7 +53,7 @@ class DPIModuleNode : public tvm::runtime::ModuleNode { */ virtual uint32_t ReadReg(int addr) = 0; -/*! \brief Kill or Exit() the accelerator */ +/*! \brief Finish hardware simulation */ virtual void Finish() = 0; static tvm::runtime::Module Load(std::string dll_name); diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index 858e1157d8b2..dbeba84f6d4a 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -74,5 +74,14 @@ def tsim_init(hw_lib): m = tvm.module.load(lib, "vta-tsim") f(m) +def tsim_cycles(): + """Get tsim clock cycles + + Returns + ------- + stats : int + tsim clock cycles + """ + return tvm.get_global_func("tvm.vta.tsim.cycles")() LIBS = _load_lib() diff --git a/vta/src/tsim/tsim_driver.cc b/vta/src/tsim/tsim_driver.cc index e0ceb9028503..6dd273c25168 100644 --- a/vta/src/tsim/tsim_driver.cc +++ b/vta/src/tsim/tsim_driver.cc @@ -28,6 +28,17 @@ namespace tsim { using vta::dpi::DPIModuleNode; using tvm::runtime::Module; +class Profiler { + public: + /*! \brief cycle counter */ + uint64_t cycle_count{0}; + + static Profiler* Global() { + static Profiler inst; + return &inst; + } +}; + class DPILoader { public: void Init(Module module) { @@ -50,6 +61,7 @@ class Device { public: Device() { dpi_ = DPILoader::Global(); + prof_ = Profiler::Global(); } int Run(vta_phy_addr_t insn_phy_addr, @@ -89,19 +101,21 @@ class Device { uint32_t wait_cycles) { // launch simulation thread dev_->Launch(wait_cycles); - dev_->WriteReg(0x10, insn_count); - dev_->WriteReg(0x14, insn_phy_addr); - dev_->WriteReg(0x18, insn_phy_addr >> 32); + // set counter to zero + dev_->WriteReg(0x04, 0); + dev_->WriteReg(0x08, insn_count); + dev_->WriteReg(0x0c, insn_phy_addr); + dev_->WriteReg(0x10, insn_phy_addr >> 32); + dev_->WriteReg(0x14, 0); + dev_->WriteReg(0x18, uop_phy_addr >> 32); dev_->WriteReg(0x1c, 0); - dev_->WriteReg(0x20, uop_phy_addr >> 32); + dev_->WriteReg(0x20, inp_phy_addr >> 32); dev_->WriteReg(0x24, 0); - dev_->WriteReg(0x28, inp_phy_addr >> 32); + dev_->WriteReg(0x28, wgt_phy_addr >> 32); dev_->WriteReg(0x2c, 0); - dev_->WriteReg(0x30, wgt_phy_addr >> 32); + dev_->WriteReg(0x30, acc_phy_addr >> 32); dev_->WriteReg(0x34, 0); - dev_->WriteReg(0x38, acc_phy_addr >> 32); - dev_->WriteReg(0x3c, 0); - dev_->WriteReg(0x40, out_phy_addr >> 32); + dev_->WriteReg(0x38, out_phy_addr >> 32); // start dev_->WriteReg(0x00, 0x1); } @@ -113,9 +127,14 @@ class Device { val &= 0x2; if (val == 0x2) break; // finish } + prof_->cycle_count = dev_->ReadReg(0x04); } + // Profiler + Profiler* prof_; + // DPI loader DPILoader* dpi_; + // DPI Module DPIModuleNode* dev_; }; @@ -128,6 +147,11 @@ TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") DPILoader::Global()->Init(m); }); +TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(Profiler::Global()->cycle_count); + }); + } // namespace tsim } // namespace vta diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index 2cedceae4e7d..815f55b5e595 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -73,8 +73,12 @@ def _run(env, remote): simulator.tsim_init("libvta_hw") f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + if env.TARGET == "tsim": + print("Load/store test took {} clock cycles".format(simulator.tsim_cycles())) + vta.testing.run(_run) @@ -135,8 +139,12 @@ def _run(env, remote): simulator.tsim_init("libvta_hw") f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + if env.TARGET == "tsim": + print("Padded load test took {} clock cycles".format(simulator.tsim_cycles())) + vta.testing.run(_run) @@ -180,7 +188,7 @@ def _run(env, remote): if not remote: return - def verify(s): + def verify(s, name=None): mod = vta.build(s, [x, w, y], "ext_dev", env.target_host) temp = util.tempdir() mod.save(temp.relpath("gemm.o")) @@ -217,6 +225,9 @@ def verify(s): np.testing.assert_equal(y_np, y_nd.asnumpy()) + if env.TARGET == "tsim": + print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles())) + def test_schedule1(): # default schedule with no smt s = tvm.create_schedule(y.op) @@ -245,7 +256,7 @@ def test_schedule1(): s[y_gem].op.axis[3], ki) s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm) - verify(s) + verify(s, name="default") def test_smt(): # test smt schedule @@ -279,7 +290,7 @@ def test_smt(): s[w_buf].compute_at(s[y_gem], ko) s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) s[y].pragma(abo2, env.dma_copy) - verify(s) + verify(s, name="smt") test_schedule1() test_smt() @@ -288,7 +299,7 @@ def test_smt(): def test_alu(): def _run(env, remote): - def check_alu(tvm_op, np_op=None, use_imm=False): + def check_alu(tvm_op, np_op=None, use_imm=False, test_name=None): """Test ALU""" m = 8 n = 8 @@ -371,14 +382,18 @@ def check_alu(tvm_op, np_op=None, use_imm=False): else: b_nd = tvm.nd.array(b_np, ctx) f(a_nd, b_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) - check_alu(lambda x, y: x << y, np.left_shift, use_imm=True) - check_alu(tvm.max, np.maximum, use_imm=True) - check_alu(tvm.max, np.maximum) - check_alu(lambda x, y: x + y, use_imm=True) - check_alu(lambda x, y: x + y) - check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) + if env.TARGET == "tsim": + print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles())) + + check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL") + check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX") + check_alu(tvm.max, np.maximum, test_name="MAX") + check_alu(lambda x, y: x + y, use_imm=True, test_name="ADD") + check_alu(lambda x, y: x + y, test_name="ADD") + check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True, test_name="SHR") vta.testing.run(_run) @@ -440,8 +455,12 @@ def _run(env, remote): simulator.tsim_init("libvta_hw") f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + if env.TARGET == "tsim": + print("Relu test took {} clock cycles".format(simulator.tsim_cycles())) + vta.testing.run(_run) @@ -503,8 +522,12 @@ def _run(env, remote): simulator.tsim_init("libvta_hw") f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + if env.TARGET == "tsim": + print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles())) + vta.testing.run(_run) @@ -521,17 +544,10 @@ def _run(env, remote): if __name__ == "__main__": - print("Array test") test_runtime_array() - print("Load/store test") test_save_load_out() - print("Padded load test") test_padded_load() - print("GEMM test") test_gemm() - print("ALU test") test_alu() - print("Relu test") test_relu() - print("Shift and scale") test_shift_and_scale() From eb675ece0e1097e7c3efa6149f1ac4e3e40d287f Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 13 Jun 2019 17:48:17 -0700 Subject: [PATCH 136/176] [TEST][FLAKY] Fix flaky test on topk and quantize pass (#3362) * fix flaky test * fix flaky quantize pass --- tests/python/relay/test_op_level6.py | 6 +++--- tests/python/relay/test_pass_quantize.py | 3 ++- topi/tests/python/test_topi_sort.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 76478baf5a19..286776e3f7b2 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -80,12 +80,12 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): tvm.testing.assert_allclose(op_res.asnumpy(), np_values) else: tvm.testing.assert_allclose(op_res.asnumpy(), np_indices) + np.random.seed(0) for k in [0, 1, 5]: for axis in [0, -1, 1]: for ret_type in ["both", "values", "indices"]: - for dtype in ["int64", "float32"]: - verify_topk(k, axis, ret_type, False, dtype) - verify_topk(k, axis, ret_type, True, dtype) + verify_topk(k, axis, ret_type, True, "int64") + verify_topk(k, axis, ret_type, False, "float32") if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index 1630efce7f6c..e02601e926f2 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -75,6 +75,8 @@ def make_qgraph(data, weight): out = relay.Function(relay.ir_pass.free_vars(out), out) return out + np.random.seed(42) + data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) graph = make_graph(data) dataset, params = make_dataset(graph, 10) @@ -95,6 +97,5 @@ def make_qgraph(data, weight): if __name__ == "__main__": - np.random.seed(42) test_simulated_quantize() test_quantize_pass() diff --git a/topi/tests/python/test_topi_sort.py b/topi/tests/python/test_topi_sort.py index ed902b982a2b..c084a7c431b6 100644 --- a/topi/tests/python/test_topi_sort.py +++ b/topi/tests/python/test_topi_sort.py @@ -96,12 +96,12 @@ def check_device(device): check_device(device) def test_topk(): + np.random.seed(0) for k in [0, 1, 5]: for axis in [0, -1, 1]: for ret_type in ["both", "values", "indices"]: - for dtype in ["int64", "float32"]: - verify_topk(k, axis, ret_type, True, dtype) - verify_topk(k, axis, ret_type, False, dtype) + verify_topk(k, axis, ret_type, True, "int64") + verify_topk(k, axis, ret_type, False, "float32") if __name__ == "__main__": From 4a0b742a8c2d2d89d9b850113bb6c15e95e8b7bb Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Fri, 14 Jun 2019 01:01:00 -0700 Subject: [PATCH 137/176] fix hardware-makefile for osx, bugfix chisel-RegFile, and rename driver (#3371) --- vta/apps/tsim_example/Makefile | 10 +++------- vta/apps/tsim_example/hardware/chisel/Makefile | 12 +++++++++--- .../chisel/src/main/scala/accel/RegFile.scala | 2 +- vta/apps/tsim_example/hardware/verilog/Makefile | 11 +++++++++-- vta/apps/tsim_example/python/__init__.py | 1 + vta/apps/tsim_example/python/accel/__init__.py | 0 .../tsim_example/python/{accel/driver.py => tsim.py} | 4 ++-- vta/apps/tsim_example/tests/python/chisel_accel.py | 5 ++--- vta/apps/tsim_example/tests/python/verilog_accel.py | 5 ++--- vta/hardware/chisel/Makefile | 11 +++++++++-- 10 files changed, 38 insertions(+), 23 deletions(-) create mode 100644 vta/apps/tsim_example/python/__init__.py delete mode 100644 vta/apps/tsim_example/python/accel/__init__.py rename vta/apps/tsim_example/python/{accel/driver.py => tsim.py} (90%) diff --git a/vta/apps/tsim_example/Makefile b/vta/apps/tsim_example/Makefile index ea8358b3dfe3..b18ced840d15 100644 --- a/vta/apps/tsim_example/Makefile +++ b/vta/apps/tsim_example/Makefile @@ -20,12 +20,11 @@ export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH) BUILD_NAME = build build_dir = $(abspath .)/$(BUILD_NAME) -default: verilog driver run_verilog +default: verilog driver + python3 tests/python/verilog_accel.py -run_chisel: chisel driver +run_chisel: chisel driver python3 tests/python/chisel_accel.py - -.PHONY: cmake driver: | $(build_dir) cd $(build_dir) && cmake .. && make @@ -39,9 +38,6 @@ verilog: chisel: make -C hardware/chisel -run_verilog: - python3 tests/python/verilog_accel.py - clean: -rm -rf $(build_dir) make -C hardware/chisel clean diff --git a/vta/apps/tsim_example/hardware/chisel/Makefile b/vta/apps/tsim_example/hardware/chisel/Makefile index 463786a9a806..4f555bab6dc3 100644 --- a/vta/apps/tsim_example/hardware/chisel/Makefile +++ b/vta/apps/tsim_example/hardware/chisel/Makefile @@ -84,11 +84,17 @@ else cxx_flags += -DVM_TRACE=0 endif +# The following is to be consistent with cmake +ifeq ($(shell uname), Darwin) + lib_path = $(build_dir)/$(LIBNAME).dylib +else + lib_path = $(build_dir)/$(LIBNAME).so +endif + default: lib -lib: $(build_dir)/$(LIBNAME).so -$(build_dir)/$(LIBNAME).so: $(verilator_build_dir)/V$(TOP).cpp - echo $(cxx_files) +lib: $(lib_path) +$(lib_path): $(verilator_build_dir)/V$(TOP).cpp g++ $(cxx_flags) $(cxx_files) -o $@ verilator: $(verilator_build_dir)/V$(TOP).cpp diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala index 5fdb3529573c..92a9833ffaa3 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala @@ -53,7 +53,7 @@ class RegFile(implicit config: AccelConfig) extends Module { val finish = Input(Bool()) val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W)))) val vals = Output(Vec(config.nVals, UInt(config.regBits.W))) - val ptrs = Output(Vec(config.nPtrs, UInt(config.regBits.W))) + val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W))) val host = new VTAHostDPIClient }) val sIdle :: sRead :: Nil = Enum(2) diff --git a/vta/apps/tsim_example/hardware/verilog/Makefile b/vta/apps/tsim_example/hardware/verilog/Makefile index 8a4369aa8075..9617a07ad565 100644 --- a/vta/apps/tsim_example/hardware/verilog/Makefile +++ b/vta/apps/tsim_example/hardware/verilog/Makefile @@ -83,10 +83,17 @@ else cxx_flags += -DVM_TRACE=0 endif +# The following is to be consistent with cmake +ifeq ($(shell uname), Darwin) + lib_path = $(build_dir)/$(LIBNAME).dylib +else + lib_path = $(build_dir)/$(LIBNAME).so +endif + default: lib -lib: $(build_dir)/$(LIBNAME).so -$(build_dir)/$(LIBNAME).so: $(build_dir)/V$(TOP).cpp +lib: $(lib_path) +$(lib_path): $(build_dir)/V$(TOP).cpp g++ $(cxx_flags) $(cxx_files) -o $@ verilator: $(build_dir)/V$(TOP).cpp diff --git a/vta/apps/tsim_example/python/__init__.py b/vta/apps/tsim_example/python/__init__.py new file mode 100644 index 000000000000..784036f7d0ae --- /dev/null +++ b/vta/apps/tsim_example/python/__init__.py @@ -0,0 +1 @@ +from . import tsim diff --git a/vta/apps/tsim_example/python/accel/__init__.py b/vta/apps/tsim_example/python/accel/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vta/apps/tsim_example/python/accel/driver.py b/vta/apps/tsim_example/python/tsim.py similarity index 90% rename from vta/apps/tsim_example/python/accel/driver.py rename to vta/apps/tsim_example/python/tsim.py index 6d8e7181b707..a41d904ab006 100644 --- a/vta/apps/tsim_example/python/accel/driver.py +++ b/vta/apps/tsim_example/python/tsim.py @@ -34,8 +34,8 @@ def driver(hw_backend): _sw_libname = "libsw" + _ext _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) if hw_backend in ("verilog", "chisel"): - _hw_lib = osp.join(_cur_path, "..", "..", "hardware", hw_backend, "build", _hw_libname) - _sw_lib = osp.join(_cur_path, "..", "..", "build", _sw_libname) + _hw_lib = osp.join(_cur_path, "..", "hardware", hw_backend, "build", _hw_libname) + _sw_lib = osp.join(_cur_path, "..", "build", _sw_libname) def load_dll(dll): try: diff --git a/vta/apps/tsim_example/tests/python/chisel_accel.py b/vta/apps/tsim_example/tests/python/chisel_accel.py index 6ab0bf5a36eb..26565c3d78eb 100644 --- a/vta/apps/tsim_example/tests/python/chisel_accel.py +++ b/vta/apps/tsim_example/tests/python/chisel_accel.py @@ -17,8 +17,7 @@ import tvm import numpy as np - -from accel.driver import driver +import tsim def test_accel(): rmax = 64 @@ -27,7 +26,7 @@ def test_accel(): ctx = tvm.cpu(0) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) - f = driver("chisel") + f = tsim.driver("chisel") cycles = f(a, b, c) msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) diff --git a/vta/apps/tsim_example/tests/python/verilog_accel.py b/vta/apps/tsim_example/tests/python/verilog_accel.py index 97f636cbfde1..d88964b1ed5c 100644 --- a/vta/apps/tsim_example/tests/python/verilog_accel.py +++ b/vta/apps/tsim_example/tests/python/verilog_accel.py @@ -17,8 +17,7 @@ import tvm import numpy as np - -from accel.driver import driver +import tsim def test_accel(): rmax = 64 @@ -27,7 +26,7 @@ def test_accel(): ctx = tvm.cpu(0) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) - f = driver("verilog") + f = tsim.driver("verilog") cycles = f(a, b, c) msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 7e90168c21c6..3c9b60148017 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -86,10 +86,17 @@ else cxx_flags += -DVM_TRACE=0 endif +# The following is to be consistent with cmake +ifeq ($(shell uname), Darwin) + lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).dylib +else + lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so +endif + default: lib -lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so -$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp +lib: $(lib_path) +$(lib_path): $(verilator_build_dir)/V$(TOP_TEST).cpp g++ $(cxx_flags) $(cxx_files) -o $@ verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp From fcc6897016b818cd1603613fe64fba8706599156 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 14 Jun 2019 10:30:46 -0700 Subject: [PATCH 138/176] [BUILD] Enable more visible symbols by default (#3365) --- CMakeLists.txt | 10 ++++++++-- Jenkinsfile | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d064c959f17b..a140c597a89f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,7 @@ tvm_option(USE_SGX "Build with SGX" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) +tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -97,8 +98,13 @@ else(MSVC) set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS} -rdynamic") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS} -rdynamic") else() - set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11 ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") + if (HIDE_PRIVATE_SYMBOLS) + message("Hide private symbols...") + set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-fvisibility=hidden ${CMAKE_CXX_FLAGS}") + endif(HIDE_PRIVATE_SYMBOLS) endif () if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) diff --git a/Jenkinsfile b/Jenkinsfile index bdbb3ecb6427..c50297402ddb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -183,6 +183,7 @@ stage('Build') { echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake + echo set\\(HIDE_PRIVATE_SYMBOLS ON\\) >> config.cmake """ make(ci_cpu, 'build', '-j4') pack_lib('cpu', tvm_lib) From c45bef423a410c55e6da1054bf859f8f9e8d03c0 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 14 Jun 2019 13:34:17 -0700 Subject: [PATCH 139/176] Add test_forward_ssd_mobilenet_v1 to tflite/test_forward (#3350) --- python/tvm/relay/testing/tf.py | 9 +++------ tests/python/frontend/tflite/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index d82ed0f46097..a56e6fe1782d 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -163,13 +163,10 @@ def get_workload_official(model_url, model_sub_path): model_sub_path: Sub path in extracted tar for the ftozen protobuf file. - temp_dir: TempDirectory - The temporary directory object to download the content. - Returns ------- - graph_def: graphdef - graph_def is the tensorflow workload for mobilenet. + model_path: str + Full path to saved model file """ @@ -200,7 +197,7 @@ def get_workload(model_path, model_sub_path=None): Returns ------- graph_def: graphdef - graph_def is the tensorflow workload for mobilenet. + graph_def is the tensorflow workload. """ diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 15357d47989c..ec345ee78961 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -598,6 +598,24 @@ def test_forward_inception_v4_net(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# SSD Mobilenet +# ------------- + +def test_forward_ssd_mobilenet_v1(): + """Test the SSD Mobilenet V1 TF Lite model.""" + # SSD MobilenetV1 + tflite_model_file = tf_testing.get_workload_official( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz", + "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + ####################################################################### # Main # ---- @@ -623,3 +641,4 @@ def test_forward_inception_v4_net(): test_forward_mobilenet_v2() test_forward_inception_v3_net() test_forward_inception_v4_net() + test_forward_ssd_mobilenet_v1() From 82a55464304889c91dd0670c6e21adc3e883703a Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 14 Jun 2019 15:18:14 -0700 Subject: [PATCH 140/176] [Relay][VM] Add AllocTensor instruction and better instruction printer (#3306) * Update vm print & add AllocTensor instruction * patch * fix invoke packed * update cmake * tweak move * update invoke_closure * lint * add doc * tweak --- CMakeLists.txt | 1 + include/tvm/runtime/vm.h | 36 ++++-- src/relay/backend/vm/compiler.cc | 75 ++---------- src/runtime/vm/vm.cc | 202 ++++++++++++++++++++----------- 4 files changed, 170 insertions(+), 144 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a140c597a89f..80b121477631 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -226,6 +226,7 @@ add_library(tvm_runtime_static STATIC ${RUNTIME_SRCS}) if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") +else() set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") endif(USE_RELAY_DEBUG) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 8911ad499e4c..028a5ff9d1ad 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -56,13 +56,14 @@ enum class Opcode { InvokeClosure = 3U, InvokePacked = 4U, AllocTensor = 5U, - AllocDatatype = 6U, - AllocClosure = 7U, - GetField = 8U, - If = 9U, - Select = 10U, - LoadConst = 11U, - Goto = 12U + AllocTensorReg = 6U, + AllocDatatype = 7U, + AllocClosure = 8U, + GetField = 9U, + If = 10U, + Select = 11U, + LoadConst = 12U, + Goto = 13U }; /*! \brief A single virtual machine instruction. @@ -83,11 +84,19 @@ struct Instruction { union { struct /* AllocTensor Operands */ { + /*! \brief The number of dimensions. */ + uint32_t ndim; + /*! \brief The shape of tensor. */ + int64_t* shape; + /*! \brief The datatype of tensor to be allocated. */ + DLDataType dtype; + } alloc_tensor; + struct /* AllocTensorReg Operands */ { /*! \brief The register to read the shape out of. */ RegName shape_register; /*! \brief The datatype of tensor to be allocated. */ DLDataType dtype; - }; + } alloc_tensor_reg; struct /* InvokeClosure Operands */ { /*! \brief The register containing the closure. */ RegName closure; @@ -192,13 +201,20 @@ struct Instruction { */ static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args); - /*! \brief Construct an allocate tensor instruction. + /*! \brief Construct an allocate tensor instruction with constant shape. + * \param shape The shape of the tensor. + * \param dtype The dtype of the tensor. + * \param dst The destination register. + * \return The allocate tensor instruction. + */ + static Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst); + /*! \brief Construct an allocate tensor instruction with register. * \param shape_register The register containing the shape. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst); /*! \brief Construct an allocate datatype instruction. * \param tag The datatype tag. * \param num_fields The number of fields for the datatype. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9b4ab6b8f6c8..3e41ce717e71 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -103,13 +103,6 @@ struct ConstantPool : ExprVisitor { } } - void AddConstantTensorShape(TensorType expr, NDArray value) { - auto it = this->const_tensor_shape_map.find(expr); - if (it == this->const_tensor_shape_map.end()) { - this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)}); - } - } - void VisitExpr_(const ConstantNode* const_node) { auto konst = GetRef(const_node); auto it = this->const_map.find(konst); @@ -117,48 +110,6 @@ struct ConstantPool : ExprVisitor { this->const_map.insert({konst, index++}); } } - - NDArray GetTensorConstant(const TensorTypeNode* ttype) { - std::vector shapes; - for (auto sh : ttype->shape) { - shapes.push_back(Downcast(sh)->value); - } - int64_t s = shapes.size(); - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx); - int64_t* dims = static_cast(shape_tensor->data); - for (size_t i = 0; i < shapes.size(); ++i) { - dims[i] = shapes[i]; - } - return shape_tensor; - } - - void VisitExpr_(const CallNode* call_node) { - for (auto arg : call_node->args) { - this->VisitExpr(arg); - } - - Expr op = call_node->op; - auto func_node = op.as(); - if (func_node) { - auto ret_type = call_node->checked_type(); - if (const TensorTypeNode* ttype = ret_type.as()) { - auto shape = GetTensorConstant(ttype); - auto tensor_type = GetRef(ttype); - AddConstantTensorShape(tensor_type, shape); - } else if (const TupleTypeNode* ttype = ret_type.as()) { - for (size_t i = 0; i < ttype->fields.size(); ++i) { - auto f = ttype->fields[i]; - auto f_type = f.as(); - auto shape = GetTensorConstant(f_type); - auto tensor_type = GetRef(f_type); - AddConstantTensorShape(tensor_type, shape); - } - } - } - } }; std::tuple LayoutConstantPool(const Module& module) { @@ -206,6 +157,7 @@ struct VMCompiler : ExprFunctor { switch (instr.op) { case Opcode::AllocDatatype: case Opcode::AllocTensor: + case Opcode::AllocTensorReg: case Opcode::GetField: case Opcode::LoadConst: case Opcode::Select: @@ -259,14 +211,14 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const MatchNode* match_node) { auto match = GetRef(match_node); - LOG(FATAL) << "translation of match nodes to the VM is " - << "currently unsupported" << std::endl; + LOG(FATAL) << "translation of match nodes to the VM is" + << "currently unsupported"; } void VisitExpr_(const LetNode* let_node) { - DLOG(INFO) << let_node->value << std::endl; + DLOG(INFO) << let_node->value; this->VisitExpr(let_node->value); - DLOG(INFO) << this->last_register << std::endl; + DLOG(INFO) << this->last_register; var_register_map.insert({let_node->var, this->last_register}); this->VisitExpr(let_node->body); } @@ -327,18 +279,13 @@ struct VMCompiler : ExprFunctor { } Instruction AllocTensorFromType(const TensorTypeNode* ttype) { - DataType dtype = ttype->dtype; - TVMType dltype = Type2TVMType(dtype); - + TVMType dltype = Type2TVMType(ttype->dtype); auto tensor_type = GetRef(ttype); - auto it = this->context->const_tensor_shape_map.find(tensor_type); - if (it == this->context->const_tensor_shape_map.end()) { - DLOG(INFO) << "Can not find constant shape for " << tensor_type; - } else { - Emit(Instruction::LoadConst(it->second.first, NewRegister())); + std::vector shape; + for (auto dim : tensor_type->shape) { + shape.push_back(Downcast(dim)->value); } - - return Instruction::AllocTensor(last_register, dltype, NewRegister()); + return Instruction::AllocTensor(shape, dltype, NewRegister()); } void EmitInvokePrimitive(const Function& func, @@ -532,7 +479,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, } VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { - DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl; + DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false); size_t params = func->params.size(); VMCompiler compiler(context); compiler.Compile(func); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 6f9190e8907a..5ba20982e90f 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -67,8 +67,14 @@ Instruction::Instruction(const Instruction& instr) { this->result = instr.result; return; case Opcode::AllocTensor: - this->shape_register = instr.shape_register; - this->dtype = instr.dtype; + this->alloc_tensor.ndim = instr.alloc_tensor.ndim; + this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, + instr.alloc_tensor.ndim); + this->alloc_tensor.dtype = instr.alloc_tensor.dtype; + return; + case Opcode::AllocTensorReg: + this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; + this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; case Opcode::AllocDatatype: this->constructor_tag = instr.constructor_tag; @@ -142,8 +148,14 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->result = instr.result; return *this; case Opcode::AllocTensor: - this->shape_register = instr.shape_register; - this->dtype = instr.dtype; + this->alloc_tensor.ndim = instr.alloc_tensor.ndim; + this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, + instr.alloc_tensor.ndim); + this->alloc_tensor.dtype = instr.alloc_tensor.dtype; + return *this; + case Opcode::AllocTensorReg: + this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; + this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; case Opcode::AllocDatatype: this->constructor_tag = instr.constructor_tag; @@ -203,12 +215,15 @@ Instruction::~Instruction() { case Opcode::Move: case Opcode::Select: case Opcode::Ret: - case Opcode::AllocTensor: + case Opcode::AllocTensorReg: case Opcode::If: case Opcode::LoadConst: case Opcode::GetField: case Opcode::Goto: return; + case Opcode::AllocTensor: + delete this->alloc_tensor.shape; + return; case Opcode::AllocDatatype: delete this->datatype_fields; return; @@ -226,8 +241,7 @@ Instruction::~Instruction() { return; default: std::ostringstream out; - LOG(FATAL) << "Invalid instruction " << static_cast(this->op) - << "\n"; + LOG(FATAL) << "Invalid instruction " << static_cast(this->op); } } @@ -252,12 +266,25 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out return instr; } -Instruction Instruction::AllocTensor(RegName shape_register, DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(std::vector shape, DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; - instr.shape_register = shape_register; - instr.dtype = dtype; + instr.alloc_tensor.ndim = shape.size(); + instr.alloc_tensor.shape = new int64_t[shape.size()]; + for (size_t i = 0; i < shape.size(); ++i) { + instr.alloc_tensor.shape[i] = shape[i]; + } + instr.alloc_tensor.dtype = dtype; + return instr; +} + +Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype, Index dst) { + Instruction instr; + instr.op = Opcode::AllocTensorReg; + instr.dst = dst; + instr.alloc_tensor_reg.shape_register = shape_register; + instr.alloc_tensor_reg.dtype = dtype; return instr; } @@ -381,85 +408,92 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { break; } - os << dtype.bits; - if (dtype.lanes != 0) { - os << "[" << dtype.lanes << "]"; + os << int(dtype.bits); + if (dtype.lanes != 1) { + os << "x" << dtype.lanes; } } +template +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ",") { + if (cnt == 0) { + return ""; + } + std::ostringstream oss; + oss << items[offset]; + for (int i = 1; i < cnt; ++i) { + oss << delim << items[offset + i]; + } + return oss.str(); +} + void InstructionPrint(std::ostream& os, const Instruction& instr) { switch (instr.op) { case Opcode::Move: { - os << "move " << instr.from << " " << instr.dst; + os << "move $" << instr.dst << " $" << instr.from; break; } case Opcode::Ret: { - os << "ret " << instr.result; + os << "ret $" << instr.result; break; } case Opcode::InvokePacked: { - os << "invoke_packed "; - os << instr.packed_index; - os << " " << instr.arity; - os << "("; - for (Index i = 0; i < instr.arity; ++i) { - os << instr.packed_args[i] << ","; - } - os << ")"; - os << " " << instr.output_size; + os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $" + << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ",$") + << ", out: $" + << StrJoin(instr.packed_args, instr.arity - instr.output_size, + instr.output_size, ",$") + << ")"; break; } case Opcode::AllocTensor: { - os << "alloc_tensor "; - os << instr.dst << " "; - os << instr.shape_register << " "; - DLDatatypePrint(os, instr.dtype); + os << "alloc_tensor $" << instr.dst << " [" + << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) + << "] "; + DLDatatypePrint(os, instr.alloc_tensor.dtype); + break; + } + case Opcode::AllocTensorReg: { + os << "alloc_tensor_reg $" << instr.dst << " $" + << instr.alloc_tensor_reg.shape_register << " "; + DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; } case Opcode::AllocDatatype: { - os << "alloc_data "; - os << instr.dst << " "; - os << instr.constructor_tag << " "; - os << instr.num_fields; + os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" + << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; break; } case Opcode::AllocClosure: { - os << "alloc_closure "; - os << instr.dst << " "; - os << instr.clo_index << " "; - os << instr.num_freevar << "("; - for (Index i = 0; i < instr.num_freevar; ++i) { - os << instr.free_vars[i] << ","; - } - os << ")"; + os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index + << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") + << ")"; break; } case Opcode::If: { - os << "if " - << "$" << instr.if_cond << " " << instr.true_offset << " " << instr.false_offset; + os << "if " << "$" << instr.if_cond << " " << instr.true_offset << " " + << instr.false_offset; break; } case Opcode::Invoke: { - os << "invoke " - << "$" << instr.dst << " " << instr.func_index << " " << instr.num_args << "("; - for (Index i = 0; i < instr.num_args; ++i) { - os << instr.invoke_args_registers[i] << ","; - } - os << ")"; + os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" + << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") + << ")"; break; } case Opcode::InvokeClosure: { - os << "invoke_closure " - << "$" << instr.dst << " " << instr.closure << " " << instr.closure_args_num << "()"; + os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" + << StrJoin(instr.closure_args, 0, instr.closure_args_num, ",$") + << ")"; break; } case Opcode::LoadConst: { - os << "load_const " - << "$" << instr.dst << " " << instr.const_index; + os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"; break; } case Opcode::GetField: { - os << "get_field " << instr.dst << " " << instr.object << " " << instr.field_index; + os << "get_field $" << instr.dst << " $" << instr.object << "[" + << instr.field_index << "]"; break; } case Opcode::Goto: { @@ -467,8 +501,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::Select: { - os << "select " << instr.dst << " " << instr.select_cond << " " << instr.select_op1 << " " - << instr.select_op2; + os << "select $" << instr.dst << " $" << instr.select_cond << " $" + << instr.select_op1 << " $" << instr.select_op2; break; } default: @@ -513,48 +547,64 @@ Index VirtualMachine::PopFrame() { } void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { - DLOG(INFO) << "===================\nInvoking global " << func.name << " " << args.size() - << std::endl; + DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); PushFrame(func.params, this->pc + 1, func); for (size_t i = 0; i < args.size(); ++i) { WriteRegister(i, args[i]); } - DLOG(INFO) << "func.params= " << func.params << std::endl; + DLOG(INFO) << "func.params= " << func.params; code = func.instructions.data(); pc = 0; } Object VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { - DLOG(INFO) << "Executing Function: " << std::endl << func << std::endl; + DLOG(INFO) << "Executing Function: " << std::endl << func; InvokeGlobal(func, args); Run(); auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]); - DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B\n"; + DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } Object VirtualMachine::Invoke(const std::string& name, const std::vector& args) { auto func_index = this->global_map_[name]; - DLOG(INFO) << "Invoke Global " << name << " at index " << func_index << std::endl; + DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; return Invoke(this->functions[func_index], args); } void InvokePacked(const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { - std::vector values(arg_count); - std::vector codes(arg_count); - runtime::TVMArgsSetter setter(values.data(), codes.data()); + size_t arity = 0; + for (Index i = 0; i < arg_count; i++) { + if (args[i].ptr_->tag == ObjectTag::kDatatype) { + arity += args[i].AsDatatype()->fields.size(); + } else { + ++arity; + } + } + std::vector values(arity); + std::vector codes(arity); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + int idx = 0; for (Index i = 0; i < arg_count; i++) { - NDArray data = ToNDArray(args[i]); - setter(i, data); + if (args[i].ptr_->tag == ObjectTag::kDatatype) { + auto dt_cell = args[i].AsDatatype(); + for (auto obj : dt_cell->fields) { + NDArray data = ToNDArray(obj); + setter(idx++, data); + } + } else { + NDArray data = ToNDArray(args[i]); + setter(idx++, data); + } } TVMRetValue rv; - func.CallPacked(TVMArgs(values.data(), codes.data(), arg_count), &rv); + func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } void VirtualMachine::Init(const std::vector& ctxs) { this->ctxs = ctxs; } @@ -574,7 +624,7 @@ void VirtualMachine::Run() { while (true) { main_loop: auto const& instr = this->code[this->pc]; - DLOG(INFO) << "\nExecuting(" << pc << "): "; + DLOG(INFO) << "Executing(" << pc << "): "; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG @@ -669,11 +719,23 @@ void VirtualMachine::Run() { goto main_loop; } case Opcode::AllocTensor: { + auto shape = std::vector(instr.alloc_tensor.ndim); + for (uint i = 0; i < instr.alloc_tensor.ndim; ++i) { + shape[i] = instr.alloc_tensor.shape[i]; + } + auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); + auto obj = Object::Tensor(data); + WriteRegister(instr.dst, obj); + pc++; + goto main_loop; + } + case Opcode::AllocTensorReg: { DLContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; - auto shape_tensor_obj = ReadRegister(instr.shape_register); + auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx); int64_t* dims = static_cast(shape_tensor->data); @@ -681,7 +743,7 @@ void VirtualMachine::Run() { auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); - auto data = allocator->Empty(shape, instr.dtype, ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto obj = Object::Tensor(data); WriteRegister(instr.dst, obj); pc++; From b8b93f8d762a7977c6da7cce2caee38377be195c Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 14 Jun 2019 21:34:37 -0700 Subject: [PATCH 141/176] Fix typo in word explicitly (#3376) --- nnvm/python/nnvm/frontend/caffe2.py | 2 +- nnvm/python/nnvm/frontend/onnx.py | 2 +- nnvm/python/nnvm/frontend/tensorflow.py | 2 +- nnvm/tests/python/frontend/mxnet/test_forward.py | 2 +- python/tvm/expr.py | 2 +- python/tvm/relay/frontend/caffe2.py | 2 +- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 2 +- tests/python/frontend/mxnet/test_forward.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/nnvm/python/nnvm/frontend/caffe2.py b/nnvm/python/nnvm/frontend/caffe2.py index 2b3ff5a27e01..f951db66b5a6 100644 --- a/nnvm/python/nnvm/frontend/caffe2.py +++ b/nnvm/python/nnvm/frontend/caffe2.py @@ -411,7 +411,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from Caffe2 operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index c8b050ad2343..b5e294b97fb1 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -963,7 +963,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from onnx operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index e59a4e76c465..7b4147155d93 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1550,7 +1550,7 @@ def _convert_rnn_operator(self, op_name, inputs, def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index db5534daee1a..446ebebbfc5a 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -137,7 +137,7 @@ def test_forward_fc_flatten(): def test_forward_clip(): data = mx.sym.var('data') - data = mx.sym.concat(data, -data, dim=1) # negative part explicity + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.clip(data, a_min=0, a_max=1) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index b4588e5d971a..9c8a9ab89d3b 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -222,7 +222,7 @@ def asnode(self): class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" - # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ + # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = NodeBase.__hash__ diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index e92a6226072f..eb8e717bb343 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -505,7 +505,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from Caffe2 operator to Relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 18253e498560..98ff10bd8318 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1154,7 +1154,7 @@ def _convert_operator(self, attrs, opset): """Convert ONNX operator into a Relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7319d5eb4a7e..866a6228980e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2277,7 +2277,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 8d7c15bb0be5..45e2ab58cae3 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -143,7 +143,7 @@ def test_forward_fc_flatten(): def test_forward_clip(): data = mx.sym.var('data') - data = mx.sym.concat(data, -data, dim=1) # negative part explicity + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.clip(data, a_min=0, a_max=1) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) From 964c2602c7df1d8667a22728867d27e76facbeac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Sat, 15 Jun 2019 15:08:46 -0700 Subject: [PATCH 142/176] save (#3033) save save save upstream lint remove bad changes fix build save save please the ci god Update src/relay/pass/partial_eval.cc Co-Authored-By: Wei Chen save fix test ci is ANGRY fix rebase problem fix rebase add test save save comment --- include/tvm/relay/pass.h | 11 +- include/tvm/relay/transform.h | 4 +- python/tvm/relay/ir_pass.py | 77 ++-- src/relay/ir/expr.cc | 6 +- src/relay/pass/dead_code.cc | 28 +- src/relay/pass/partial_eval.cc | 399 ++++++++++++++++--- tests/python/relay/test_pass_partial_eval.py | 270 +++++++++---- 7 files changed, 624 insertions(+), 171 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 977bb6793bb5..fff630f55eb7 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -296,13 +296,15 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); * For example, this pass should turn `let a = 1 in 2` into `2`, * as the value of the expression does not depend on a. * - * As another example, `let a = 1 in a` will be optimized into 1. + * As another example, `let a = 1 in a` will be optimized into 1, + * if the flag is turned on. * * \param e the expression to optimize. + * \param inline_once whether or not to inline binding used one. * * \return the optimized expression. */ -TVM_DLL Expr DeadCodeElimination(const Expr& e); +TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); /*! * \brief Fold constant expressions. @@ -435,11 +437,12 @@ TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. * - * \param e the expression, + * \param e the expression + * \param mod the module * * \return the optimized expression. */ -TVM_DLL Expr PartialEval(const Expr& e); +TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); /*! * \brief Bind the free variables to a Relay expression. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index f579f1c7ba91..fb8ebbf09946 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< * * As another example, `let a = 1 in a` will be optimized into 1. * + * \param inline_once whether or not to inline binding used one. + * * \return the pass. */ -TVM_DLL Pass DeadCodeElimination(); +TVM_DLL Pass DeadCodeElimination(bool inline_once = false); /*! * \brief Fold constant expressions. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8f1ceded76dd..dd0f54c664ca 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -129,7 +129,7 @@ def well_formed(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -175,7 +175,7 @@ def free_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -197,7 +197,7 @@ def bound_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -213,7 +213,7 @@ def all_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + mod : Optional[tvm.relay.Module] The global module Returns @@ -286,12 +288,12 @@ def simplify_inference(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with some simplification """ @@ -304,32 +306,34 @@ def canonicalize_ops(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression without bias_add """ return _ir_pass.canonicalize_ops(expr) -def dead_code_elimination(expr): +def dead_code_elimination(expr, inline_once=False): """ Remove expressions which does not effect the program result (dead code). Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression + inline_once : Optional[Bool] + Whether to inline binding that occur only once. Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with dead code removed. """ - return _ir_pass.dead_code_elimination(expr) + return _ir_pass.dead_code_elimination(expr, inline_once) def alpha_equal(lhs, rhs): @@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is alpha equal to rhs. """ return bool(_make._alpha_equal(lhs, rhs)) @@ -359,15 +363,15 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is data-flow equivalent to rhs. """ return bool(_make._graph_equal(lhs, rhs)) @@ -378,12 +382,12 @@ def structural_hash(value): Parameters ---------- - expr: tvm.relay.Expr or tvm.relay.Type + expr : Union[tvm.relay.Expr, tvm.relay.Type] The expression to hash. Returns ------- - result: int + result : int The hash value """ if isinstance(value, Expr): @@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None): expr : tvm.relay.Expr The input expression. - mod: Optional[tvm.relay.Module] + mod : Optional[tvm.relay.Module] The global module. Returns ------- - expr: tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ return _ir_pass.to_a_normal_form(expr, mod) @@ -563,7 +567,7 @@ def to_graph_normal_form(expr): The input expression Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression """ return _ir_pass.to_graph_normal_form(expr) @@ -612,7 +616,7 @@ def get_total_mac_number(expr): Returns ------- - ret : int64 + result : int64 The number of MACs (multiply-accumulate) of a model """ return _ir_pass.GetTotalMacNumber(expr) @@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None): expr : tvm.relay.Expr The input expression. - fskip: function + fskip : function The callback function that decides whether an expression should be skipped. Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ return _ir_pass.eliminate_common_subexpr(expr, fskip) -def partial_evaluate(expr): +def partial_evaluate(expr, mod=None): """ Evaluate the static fragment of the code. @@ -646,12 +650,15 @@ def partial_evaluate(expr): expr : tvm.relay.Expr The input expression. + mod : Optional[tvm.relay.Module] + The global module + Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ - return _ir_pass.partial_evaluate(expr) + return _ir_pass.partial_evaluate(expr, mod) def unmatched_cases(match, mod=None): """ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 64706933fde3..e0ec10a87061 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; }); Let LetNode::make(Var var, Expr value, Expr body) { @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_API("relay._expr.TempExprRealize") .set_body_typed([](TempExpr temp) { - return temp->Realize(); + return temp->Realize(); }); } // namespace relay diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index be6774564806..7e186f80df92 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -38,10 +38,10 @@ namespace relay { // calculate the dependency graph from expression class CalcDep : private ExprVisitor { public: - static Expr Eliminate(const Expr& e) { + static Expr Eliminate(const Expr& e, bool inline_once) { CalcDep cd; cd.Calculate(e); - Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); + Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once); return el(e); } @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor { VarMap expr_map_; VarMap use_map_; VarSet letrec_set_; + bool inline_once_; explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, - const VarSet& letrec_set) : - expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } + const VarSet& letrec_set, + bool inline_once) : + expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { } friend CalcDep; bool HasLet(const Var& v) { - // TODO(@jroesch): MK fix me - return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + switch (use_map_[v]) { + case 0: + return false; + case 1: + return letrec_set_.count(v) > 0 || !inline_once_; + default: + return true; + } } Expr VisitExpr_(const VarNode* op) final { @@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor { }; }; -Expr DeadCodeElimination(const Expr& e) { - return CalcDep::Eliminate(e); +Expr DeadCodeElimination(const Expr& e, bool inline_once) { + return CalcDep::Eliminate(e, inline_once); } TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") @@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") namespace transform { -Pass DeadCodeElimination() { +Pass DeadCodeElimination(bool inline_once) { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(DeadCodeElimination(f)); + return Downcast(DeadCodeElimination(f, inline_once)); }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 71ba7cd11cd5..07ec1b0711ae 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -74,28 +74,19 @@ * * The partial evaluator makes several assumptions, so there is room for improvement: * - * 0: The partial evaluator treats global variables as opaque. - * Doing PartialEval on a module level will solve this. - * - * 1: The partial evaluator assume all functions as terminating. - * We need to has a max_expand parameter that shrink on every compile time evaluation, - * to make sure PE does not infinite loop. - * Additionally, we might add a termination analysis pass that lift this requirement - * for function that analysis found terminating. - * - * 2: Every time an unknown effect happened, we clear the whole store. + * 0: Every time an unknown effect happened, we clear the whole store. * It is too conservative: if a local reference is created (and do not get passed outside), * An unknown global function call/global reference write can not modify it. * We can pair PE with escape analysis/alias analysis. * - * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise. + * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise. * - * 4: When doing pattern matching, we can simplify the match even for dynamic case. + * 2: When doing pattern matching, we can simplify the match even for dynamic case. * Right now it is all or nothing: either a complete match, or the original dynamic code. * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form. * We then can reify the result. * - * 5: Every time a function is called, it's code will get expanded and partially evaluated. + * 3: Every time a function is called, its code will get expanded and partially evaluated. * We can do a binding time analysis to cache the result and avoid re-partial evaluation. * * These assumptions do not affect the correctness of the algorithm, however. @@ -104,6 +95,7 @@ #include #include #include +#include "../ir/type_functor.h" #include "pass_util.h" #include "let_list.h" @@ -132,6 +124,8 @@ struct VarEqual { } }; +Expr PostProcess(const Expr&); + /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: @@ -150,10 +144,20 @@ class Static : public NodeRef { using ContainerType = StaticNode; }; +using Time = size_t; + struct PStaticNode : Node { + static Time time() { + static Time time_ = 0; + Time ret = time_; + time_++; + return ret; + } Static pstatic; // may be null Expr dynamic; - PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + Time created_time; + PStaticNode(const Static& pstatic, const Expr& dynamic) : + pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -341,6 +345,7 @@ class Store { }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { + CHECK(stat.defined()); return PStatic(make_node(stat, dynamic)); } @@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() { return CreateInterpreter(Module(nullptr), CPUContext(), target); } +bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + +using FuncId = int; + +/*! + * \brief Annotate a function with a FuncId. + */ +struct WithFuncIdAttrs : public tvm::AttrsNode { + FuncId fid; + + TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { + TVM_ATTR_FIELD(fid) + .describe("The FuncId that an function is annotated with.") + .set_default(-1); + } +}; + +TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); + +Op WithFuncIdOp() { + static const Op& op = Op::Get("annotation.with_funcid"); + return op; +} + +Expr MkWithFuncId(const Expr& expr, FuncId fid) { + auto attrs = make_node(); + attrs->fid = fid; + return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("annotation.with_funcid") +.describe(R"code(Annotate a function with a funcid.)code" +TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("func", "Function", "The input data."); + +Expr StripWithFuncId(const Expr& e); + +Expr DeDup(const Expr& e); + +Function AsFunc(const Expr& e) { + if (e.as()) { + return Downcast(e); + } else if (const CallNode* c = e.as()) { + CHECK(c->op.same_as(WithFuncIdOp())); + CHECK_EQ(c->args.size(), 1); + return AsFunc(c->args[0]); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } +} + class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars) { + PartialEvaluator(const tvm::Array& free_vars, + const Module& mod) : + mod_(mod) { for (const Var& v : free_vars) { env_.Insert(v, NoStatic(v)); } } + PStatic VisitExpr(const Expr& e, LetList* ll) final { + PStatic ret = ExprFunctor::VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - return NoStatic(GetRef(op)); + GlobalVar gv = GetRef(op); + if (gv_map_.count(gv) == 0) { + if (mod_.defined()) { + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); + } else { + gv_map_.insert({gv, NoStatic(gv)}); + } + } + return gv_map_.at(gv); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { @@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { + if (op->op.same_as(WithFuncIdOp())) { + CHECK_EQ(op->args.size(), 1); + return VisitExpr(op->args[0], ll); + } PStatic f = VisitExpr(op->op, ll); std::vector x; tvm::Array x_dyn; @@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - Function func = GetRef(op); + struct TimeFrame { + PartialEvaluator* pe_; + FuncId fid_; + std::vector