[TOPI][Relay][TensorFlow] Add OneHot operator #3781
Conversation
Some extra changes needed:
- Add test case to tests/python/relay/test_op_level10.py
- Update doc in docs/langref/relay_op.rst and docs/api/python/topi.rst
python/tvm/relay/op/transform.py (outdated):

    indices = [1., 2., 3.]

    relay.one_hot(indices, 2) =
Should the depth be 3 in this example?
Yes, you're right, I'll fix that.
python/tvm/relay/op/transform.py (outdated):

    --------
    .. code-block:: python

        indices = [1., 2., 3.]
Should indices be [0, 1, 2] in this example?
Fixed
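For reference, a minimal sketch of what the corrected docstring example could look like, written against the op's eventual signature once on_value and off_value became inputs (the argument order and keyword names here are assumed, and the expected output follows from on_value = 1, off_value = 0):

```python
from tvm import relay

# Sketch only: corrected example with depth = 3 and integer indices.
indices = relay.const([0, 1, 2], dtype="int32")
out = relay.one_hot(indices,
                    relay.const(1.0),   # on_value
                    relay.const(0.0),   # off_value
                    depth=3, axis=-1, dtype="float32")
# With on_value = 1 and off_value = 0 this produces:
# [[1, 0, 0],
#  [0, 1, 0],
#  [0, 0, 1]]
```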
topi/include/topi/transform.h (outdated):

    auto idx = iter_vars[iter_vars.size() - 1];
    return tvm::if_then_else(indices(outer_indices) == idx, 1, 0);
You should use ir::Select::make here for better performance.
Do you require indices to be an int type? Otherwise, this line could fail if indices.dtype is float32 without a type cast. You should add a test case for indices with float32 type.
Can you expand on how to use ir::Select::make here? And yes, indices must be an int type.
Just replace tvm::if_then_else with ir::Select::make. You can find a use case at https://github.com/dmlc/tvm/blob/master/topi/include/topi/nn/pooling.h#L265.
Updated
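To illustrate the suggested change: Select requires both branches to be side-effect free and may evaluate both, which typically lowers to a select instruction and vectorizes better than a conditional, and the one-hot branches are just constants, so this is safe. The PR makes the swap in C++ (tvm::if_then_else to ir::Select::make in topi/include/topi/transform.h); the snippet below is only a rough Python TE analogue of the same idea, using current tvm.tir names rather than the PR's code, with the helper name, shapes, and int32 dtype all illustrative.

```python
import tvm
from tvm import te

def one_hot_sketch(indices, depth):
    # indices is assumed to be an integer-typed 1-D tensor, as discussed above.
    def compute(i, j):
        # Select may evaluate both branches, so both must be pure;
        # here they are constants, which is exactly the one-hot case.
        return tvm.tir.Select(indices[i] == j,
                              tvm.tir.const(1, "int32"),
                              tvm.tir.const(0, "int32"))
    return te.compute((indices.shape[0], depth), compute, name="one_hot")

idx = te.placeholder((4,), dtype="int32", name="indices")
out = one_hot_sketch(idx, 3)
```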
Could you also add support for on_value, off_value, axis, and dtype?
@icemelon9 sure, I can do that. The tricky part about on_value and off_value is that they can be either floats or ints, meaning that OneHotAttrs will not necessarily have the correct data type. What do you think is the best way to handle this? Should on_value and off_value be of type float in OneHotAttrs, with any casting done in the TOPI definition?
You can just cast them to the dtype in the attrs, and pass on_value and off_value as Expr arguments rather than as attributes. You can check out the implementation of the full op.
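As a concrete (assumed) illustration of this suggestion: on_value and off_value arrive as scalar Exprs and are cast to the requested dtype before reaching the op, much as relay.full takes its fill value as an Expr rather than an attribute. The names, shapes, and dtypes below are illustrative, not the PR's exact code.

```python
from tvm import relay

# Hypothetical converter-style snippet: the on/off scalars come in as
# int64 but the requested output dtype is float32, so cast them first.
indices = relay.var("indices", shape=(4,), dtype="int32")
on_value = relay.cast(relay.const(5, "int64"), "float32")
off_value = relay.cast(relay.const(0, "int64"), "float32")
out = relay.one_hot(indices, on_value, off_value,
                    depth=3, axis=-1, dtype="float32")
```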
Added support for on_value, off_value, axis, and dtype. Let me know what you think!
@icemelon9 would you mind looking at the updates?
lgtm
Added to the mxnet frontend
It does not seem necessary to pass on_value and off_value as tvm.relay.Expr; why not just treat them as attributes? That way, we can avoid allocating two more tensors.
The on_value and off_value can be any data type. Treating them as tvm.relay.Expr makes the casting to the correct data type easier.
* Add one-hot to Relay
* topi implementation
* Working
* add topi test
* Add TF test
* Fix check
* fix linting issues
* fix documentation
* Fix documentation
* Add support for on_value, off_value, axis, dtype
* Add full support for axis
* Fix compute and update test_forward
* Move on_value and off_value to inputs
* Add topi test
* Update tests
* Update docs
* Fix style
* re-enable tests
* Add one_hot to mxnet converter
Add a one-hot operator where on_value is 1 and off_value is 0. This implementation is enough to support most scenarios; we can add support for configurable on_value and off_value later if it is needed.
@jroesch @icemelon9 would you be able to take a look?
adding @Laurawly and @Huyuwei