Dynamic ONNX Importer (apache#6351)
* Change onnx importer to use dynamic upsampling3d (#3)

fix pylint

* Refactor ONNX frontend to be dynamic

Make OneHot dynamic

Support BatchMatMul with dynamically shaped inputs

fix dynamic broadcast

Add null checks to broadcast_to rel functions

fail more isolated broadcast_to test

use StructuralEqual instead of pointer comparisons in dynamic_to_static pass

add an optional weight freeze argument to onnx importer (usage sketch after this list)

convert onnx resize to dynamic op

add dynamic expand to onnx importer

add a shape_func for power

fix BERTSquad, lint

handle onnx graph initializer parameters more intelligently

* Dynamic ONNX importer: Upsampling and Pad (#2)

fix lint

fix Call reference

fix a type issue with expand

fix a bad test refactor

respond to review comments, fix batch matmul tests

* black format

* fix batch matmul test

* add dynamic strided slice to the onnx importer

* fix clip importer

* fix qnn tutorial

* fix bad merge, respond to review comments

* add a simple dynamic model test

* Add dynamic-shaped autopadding to convolution and pooling ops

* fix dynamic issues in a few ops

* fix pylint

* disable tests onnxrt doesn't support

* fix pytorch test

* respond to review comments

* add documentation about partially supporting dynamic shapes

Co-authored-by: Lily Orth-Smith <[email protected]>
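The weight-freeze and dynamic-to-static flow summarized in the list above can be exercised roughly as follows. This is a minimal sketch and not part of the commit itself: the model path, input name, and shape are hypothetical, and it assumes the freeze_params keyword on relay.frontend.from_onnx together with the Python binding relay.transform.DynamicToStatic() that this work exposes.

    import onnx
    import tvm
    from tvm import relay

    # Hypothetical model file and input name; substitute your own.
    onnx_model = onnx.load("model.onnx")
    shape_dict = {"input": (1, 3, 224, 224)}

    # freeze_params=True folds the graph initializers (weights) into Relay
    # constants at import time, so ops whose attributes arrive as inputs in
    # ONNX (Resize, Expand, OneHot, ...) can later be made static again.
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

    # Rewrite dynamic ops whose inputs are now constant back to static ops.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.DynamicToStatic()(mod)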
Matthew Brookhart and Lily Orth-Smith authored Oct 3, 2020
1 parent a413458 commit 2658ebe
Showing 21 changed files with 957 additions and 489 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
@@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference();
  */
 TVM_DLL Pass FastMath();
 
+/*!
+ * \brief Find Dynamic ops and make them static
+ *
+ * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
+ * them with static ops and re-performs type inference and constant folding. The pass repeats
+ * itself until the graph stops changing or we run too many iterations.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DynamicToStatic();
+
 /*!
  * \brief Infer the type of an expression.
  *
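From Python, the pass declared here is exposed as relay.transform.DynamicToStatic(). The following is a rough sketch of what it does, assuming relay.broadcast_to dispatches to the dynamic variant when its target shape is given as a Relay expression (that dispatch behavior is an assumption here, not shown in this diff):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 4), dtype="float32")
    # Target shape supplied as a constant tensor rather than as a static
    # attribute, mimicking what the ONNX importer now emits.
    target_shape = relay.const(np.array([3, 4], dtype="int64"))
    y = relay.broadcast_to(x, target_shape)

    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)
    # The shape input is a constant, so the dynamic op can be replaced by the
    # static broadcast_to with shape=[3, 4], then re-type-inferred.
    mod = relay.transform.DynamicToStatic()(mod)
    print(mod)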
11 changes: 8 additions & 3 deletions include/tvm/topi/broadcast.h
@@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
                                      << "\nvs\ninput: " << t;
   auto bh = detail::BroadcastShape(output_shape, t->shape);
   CHECK_EQ(output_shape.size(), bh.common_shape.size());
+  Array<PrimExpr> oshape;
   for (size_t i = 0; i < output_shape.size(); ++i) {
-    CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
+    if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
+      oshape.push_back(output_shape[i]);
+    } else {
+      CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
+      oshape.push_back(bh.common_shape[i]);
+    }
   }
   auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
   };
-  return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
-                          l, name, tag);
+  return tvm::te::compute(oshape, l, name, tag);
 }
 
 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
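The practical effect of this change is that broadcast_to now accepts output dimensions that are not compile-time constants: a dimension that is not an IntImm is passed through to the output shape as-is instead of being checked against the statically derived broadcast shape. A small TE-level sketch, assuming the Python topi.broadcast_to wrapper forwards to this C++ implementation:

    import tvm
    from tvm import te, topi

    n = te.var("n")                    # symbolic (dynamic) leading dimension
    x = te.placeholder((1, 4), name="x", dtype="float32")

    # Before this change, every output dimension had to match the statically
    # computed broadcast shape; with it, the symbolic `n` is propagated.
    y = topi.broadcast_to(x, (n, 4))
    print(y.shape)                     # [n, 4]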
[Diffs for the remaining 19 changed files are not shown here.]
