-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Add shape check for ConcatenateRel and StackRel #3699
Conversation
Would it be possible to have test cases/regressions? |
Yes, I can add test cases. Where can I find some examples (of test cases)? Between, currently I am using |
@slyubomirsky Did you mean tests like https://github.com/dmlc/tvm/tree/master/tests/cpp? |
They could be like those, or tests in Python (see |
Concatenate is a fundamental operator so it should at level 1. Is my understanding correct? |
@slyubomirsky I have just added a test case. Please check. Also, I note there is no check to stack relay. Consider the similar functionality of stack / concatenate, do you think it is necessary to add one? |
tests/python/relay/test_op_level1.py
Outdated
y = relay.var('p2', shape=(2, 3)) | ||
c = relay.concatenate([x, y], axis=0) | ||
func = relay.Function([x, y], c) | ||
relay_module = relay.Module.from_expr(func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you implicitly inferring type here? Should we have some assertions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain more here? Current action here is similar to run_infer_type()
used in other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I am saying that type inference implicit here when creating a module. We probably prefer to invoke run_infer_type
directly and check the expected type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. But one weird thing I met is, if I use run_infer_type()
, then no error is thrown
# no error
x = relay.var('p1', shape=(2, 5))
y = relay.var('p2', shape=(2, 3))
z = relay.concatenate([x, y], axis=1)
zz = run_infer_type(z)
# throws tvm.Error
x = relay.var('p1', shape=(2, 5))
y = relay.var('p2', shape=(2, 3))
z = relay.concatenate([x, y], axis=1)
func = relay.Function([x, y], c)
zz = run_infer_type(func)
While run_infer_type()
is
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
It seems we have to pack it into a relay.Function
then the error can be detected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems very strange, since module.from_expr
is supposed to pack a non-function into a function anyway (see https://github.com/dmlc/tvm/blob/master/src/relay/ir/module.cc#L232).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked out your branch and tried running the example you gave myself and did not get an error in either case, so I'm not sure what problem you might have had. Was there a problem before that you fixed?
Thanks for writing tests, I will look it over when I get a chance. |
Looks good to me, especially after Zhi's recommended changes. I think there should be a test for stack, since it was also modified by these changes (definitely an oversight that there wasn't one already). |
ping @Lyken17 can you look into the CI error? |
@tqchen Sorry the delay I will have a look at the CI error. Seems the change lead to doc build failed on ssd example. But I cannot reproduce on local laptop.
|
could it has something to do with the mxnet version? The error seems does correlates with the concat op if that helps |
@tqchen I have resolved the CI issue. Can you have a check? |
seems tianqi is busy. @zhiics @jroesch @slyubomirsky can u have a check? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* [Relay] add shape check for concat * [Relay] add shape check for stack * add test case for shape mismatch * [typo] add the missing assert * fix lint errors. * replace int with size_t. * statically cast param->axis to size_t. * switch to run_infer_type. * fix checking for negative index * add static_cast for param->axis * merge to latest tvm * fix lint error * Fix an error with negative index. * Update transform.h * Update transform.cc
* [Relay] add shape check for concat * [Relay] add shape check for stack * add test case for shape mismatch * [typo] add the missing assert * fix lint errors. * replace int with size_t. * statically cast param->axis to size_t. * switch to run_infer_type. * fix checking for negative index * add static_cast for param->axis * merge to latest tvm * fix lint error * Fix an error with negative index. * Update transform.h * Update transform.cc
* [Relay] add shape check for concat * [Relay] add shape check for stack * add test case for shape mismatch * [typo] add the missing assert * fix lint errors. * replace int with size_t. * statically cast param->axis to size_t. * switch to run_infer_type. * fix checking for negative index * add static_cast for param->axis * merge to latest tvm * fix lint error * Fix an error with negative index. * Update transform.h * Update transform.cc
* [Relay] add shape check for concat * [Relay] add shape check for stack * add test case for shape mismatch * [typo] add the missing assert * fix lint errors. * replace int with size_t. * statically cast param->axis to size_t. * switch to run_infer_type. * fix checking for negative index * add static_cast for param->axis * merge to latest tvm * fix lint error * Fix an error with negative index. * Update transform.h * Update transform.cc
As discussed in [Relay][Concatenate] Missing shape checking for non-concat axes, it is better to add shape checking in the relay. This pull request aims to complete the missing shape checking for non-concating / non-stacking axes.