Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][Relay][TensorFlow] Add OneHot operator #3781

Merged
merged 19 commits into from
Aug 22, 2019

Conversation

soiferj
Copy link
Contributor

@soiferj soiferj commented Aug 15, 2019

Add one-hot operator where on_value is 1 and off_value is 0. This implementation is enough to support most scenarios. We can add support for on_value and off_value at a later time if it is needed.

@jroesch @icemelon9 would you be able to take a look?

adding @Laurawly and @Huyuwei

@soiferj soiferj changed the title [Topi][Relay][TensorFlow] Add OneHot operator [TOPI][Relay][TensorFlow] Add OneHot operator Aug 15, 2019
Copy link
Member

@icemelon icemelon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some extra changes needed:

  • Add test case to tests/python/relay/test_op_level10.py
  • Update doc in docs/langref/relay_op.rst and docs/api/python/topi.rst


indices = [1., 2., 3.]

relay.one_hot(indices, 2) =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the depth be 3 in this example?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right, I'll fix that.

--------
.. code-block:: python

indices = [1., 2., 3.]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should indices be [0, 1, 2] in this example?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

}

auto idx = iter_vars[iter_vars.size() - 1];
return tvm::if_then_else(indices(outer_indices) == idx, 1, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should use ir::Select::make here for better performance

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you require indices to be int type? Otherwise, this line could fail if indices.dtype is float32 without type cast.
You should add a test case for indices with float32 type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on how to use ir::Select::make here? And yes, indices must be an int type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just replace tvm::if_then_else by ir::Select::make. You can find use case at https://github.com/dmlc/tvm/blob/master/topi/include/topi/nn/pooling.h#L265.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@icemelon
Copy link
Member

Could you also add on_value, off_value, axis, dtype in the OneHotAttrs same as TF definition and test case correspondingly?

@soiferj
Copy link
Contributor Author

soiferj commented Aug 19, 2019

@icemelon9 sure, I can do that. The tricky part about on_value and off_value is that it can be either a float or an int, meaning that OneHotAttrs will not necessarily have the correct data type. What do you think is the best way to handle this? Should I have the on_value and off_value be of type float in OneHotAttrs, and do any casting in the TOPI definition?

@icemelon
Copy link
Member

icemelon commented Aug 19, 2019

You can just cast it to the dtype in the attr, and put on_value and off_value as Expr in argument not in the attribute. You can check out implementation of full op:

@soiferj
Copy link
Contributor Author

soiferj commented Aug 20, 2019

Added support for on_value, off_value, axis, and dtype. Let me know what you think!

@soiferj
Copy link
Contributor Author

soiferj commented Aug 22, 2019

@icemelon9 would you mind looking at the updates?

@icemelon
Copy link
Member

lgtm
Could you also help add the converter for mxnet?

@soiferj
Copy link
Contributor Author

soiferj commented Aug 22, 2019

Added to the mxnet frontend

@icemelon icemelon merged commit 554df21 into apache:master Aug 22, 2019
@icemelon
Copy link
Member

Thanks @soiferj @vinx13 !

@soiferj soiferj deleted the soiferj/one_hot branch August 22, 2019 21:12
@lixiaoquan
Copy link
Contributor

It seems not necessary to pass on_value and off_value as tvm.relay.Expr, why not just treat them as attributes? That way, we can avoid allocating two more tensors.

@soiferj
Copy link
Contributor Author

soiferj commented Aug 26, 2019

The on_value and off_value can be any data type. Treating them as tvm.relay.Expr makes the casting to the correct data type easier.

wweic pushed a commit to wweic/tvm that referenced this pull request Sep 16, 2019
* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter
wweic pushed a commit to wweic/tvm that referenced this pull request Sep 16, 2019
* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter
wweic pushed a commit to neo-ai/tvm that referenced this pull request Sep 16, 2019
* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants