-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][RFC] Implement type checking for Any #3221
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
first pass on debugging prints mostly. overall lgtm.
DataType dtype; | ||
|
||
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") { | ||
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why removing the defaults?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't easily make Relay constant values in C++.
Remove code generation related changes Remove compile changes Remove more Remove unification hack Add some code back that was needed, and clean up test Refactor test cases WIP Implement TypeHint AST Add test case which should fail Remove unification changes, and fix bug with let rec Restore unification for shapes Improve error reporting while debugging All examples type check All examples type check WIP First version that works with hints, needs clean up Remove dead code Tweaks Remove type hint Remove unecessary type hint stuff Remove more type hints Clean up Expose Any expression node Address CR Fix Fix solver Kill unecessary code Fix PyLint Fix Relocate loops Fix license and test Lint again Lint again Fix loops Fix docstring Fix template error Fix compiler issue Fix compile err Remove more runtime changes Restore buffer Fix segfault Fix Fix arange
|
||
auto mod = [solver](std::string name) -> PackedFunc { | ||
auto mod = [solver, err_reporter](std::string name) -> PackedFunc { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BAD! This will never get free!
use a shared_ptr!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its fine, this is testing code dude
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm.
CHECK_EQ(ndim(), indices.size()) | ||
<< "Tensor dimension mismatch in read" | ||
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); | ||
if (ndim() != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why add ndim() != 0
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to allow retrieval of the "first" element from scalars.
@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc< | |||
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call, | |||
const Expr& output_grad)>; | |||
|
|||
/*! | |||
* \brief The codegeneration strategy for dynamic dimensions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
code generation
@@ -664,7 +664,9 @@ class PrettyPrinter : | |||
Doc PrintAttr(const NodeRef& value, bool meta = false) { | |||
if (value.defined()) { | |||
Doc printed_attr; | |||
if (meta) { | |||
if (value.as<tvm::ir::Any>()) { | |||
printed_attr << "?"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could this be ambiguous? why not print "any"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kind of like ?
instead of any
because the shape will still be checked dynamically, so even though we're calling the node "any," not literally any shape will work
func = relay.Function([start], relay.TupleGetItem(body, 1)) | ||
func = infer_type(func) | ||
# TODO(@jroesch, @haichen): We should restore this code when codegeneration | ||
# is merged |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a PR we can reference?
* Implement type checking for Any Remove code generation related changes Remove compile changes Remove more Remove unification hack Add some code back that was needed, and clean up test Refactor test cases WIP Implement TypeHint AST Add test case which should fail Remove unification changes, and fix bug with let rec Restore unification for shapes Improve error reporting while debugging All examples type check All examples type check WIP First version that works with hints, needs clean up Remove dead code Tweaks Remove type hint Remove unecessary type hint stuff Remove more type hints Clean up Expose Any expression node Address CR Fix Fix solver Kill unecessary code Fix PyLint Fix Relocate loops Fix license and test Lint again Lint again Fix loops Fix docstring Fix template error Fix compiler issue Fix compile err Remove more runtime changes Restore buffer Fix segfault Fix Fix arange * Address feedback * Fix typo * Fix arange * Fix op level3 * Fix issue with Python wrapper
* Implement type checking for Any Remove code generation related changes Remove compile changes Remove more Remove unification hack Add some code back that was needed, and clean up test Refactor test cases WIP Implement TypeHint AST Add test case which should fail Remove unification changes, and fix bug with let rec Restore unification for shapes Improve error reporting while debugging All examples type check All examples type check WIP First version that works with hints, needs clean up Remove dead code Tweaks Remove type hint Remove unecessary type hint stuff Remove more type hints Clean up Expose Any expression node Address CR Fix Fix solver Kill unecessary code Fix PyLint Fix Relocate loops Fix license and test Lint again Lint again Fix loops Fix docstring Fix template error Fix compiler issue Fix compile err Remove more runtime changes Restore buffer Fix segfault Fix Fix arange * Address feedback * Fix typo * Fix arange * Fix op level3 * Fix issue with Python wrapper
* Implement type checking for Any Remove code generation related changes Remove compile changes Remove more Remove unification hack Add some code back that was needed, and clean up test Refactor test cases WIP Implement TypeHint AST Add test case which should fail Remove unification changes, and fix bug with let rec Restore unification for shapes Improve error reporting while debugging All examples type check All examples type check WIP First version that works with hints, needs clean up Remove dead code Tweaks Remove type hint Remove unecessary type hint stuff Remove more type hints Clean up Expose Any expression node Address CR Fix Fix solver Kill unecessary code Fix PyLint Fix Relocate loops Fix license and test Lint again Lint again Fix loops Fix docstring Fix template error Fix compiler issue Fix compile err Remove more runtime changes Restore buffer Fix segfault Fix Fix arange * Address feedback * Fix typo * Fix arange * Fix op level3 * Fix issue with Python wrapper
Currently a draft PR, see related RFC #3042.
This PR will only contain the type checking changes to Relay to support Any. @icemelon9 and I will follow up with the related code generation PRs.