[TOPI] Add embedding op and gradient #6794
Conversation
Thanks @tkonolige. Given that this is a new op, it would be great to do an API review per https://tvm.apache.org/docs/contribute/code_review.html#deliberate-on-api-and-data-structures. In particular, it would be great to check out the conventions of similar APIs in existing frameworks like numpy, PyTorch, Keras, and TensorFlow; we should ideally follow common conventions. See previous related topics on nms in #2535.
re: API review
**Naming**

All of the above APIs call it `embedding` (`torch.nn.Embedding`, `mx.sym.Embedding`, `tf.keras.layers.Embedding`).

**Arguments**

I don't think we need to pass in the vocabulary size or embedding dimension like these examples do, since we can infer them from the weight/table matrix (I imagine they use them for bookkeeping in training, which is a separate matter). Likewise, we can ignore anything related to weight initialization. PyTorch has the following additional arguments: `padding_idx`, `max_norm`, `norm_type`, `scale_grad_by_freq`, and `sparse`.

mxnet has: `dtype` and `sparse_grad`.

TF/keras has: `embeddings_regularizer`, `embeddings_constraint`, `mask_zero`, and `input_length`.
In my opinion, we should aim for PyTorch's API over TF/Keras, but perhaps others can give more insight. We are also thinking about adding sparse gradient support, so it may be best to add it as an option later.

**Shapes**

PyTorch and mxnet support arbitrary input shape. In particular, if our embedding dimension is `d`, indices of shape `(a, b)` produce an output of shape `(a, b, d)`. TF/Keras is strange as they have a fixed 2D `(batch_size, input_length)` input. This PR currently proposes following the PyTorch convention.
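A quick illustration of the two points above, that vocabulary size and embedding dimension are inferable from the table and that arbitrary index shapes work naturally. Plain NumPy, for illustration only; the shapes and variable names here are made up:

```python
import numpy as np

weight = np.random.randn(10000, 128).astype("float32")     # (vocab, dim): both inferable from the table
indices = np.array([[1, 5, 2], [7, 0, 3]], dtype="int64")  # indices of arbitrary shape

out = weight[indices]  # fancy indexing along axis 0 is exactly an embedding lookup
assert out.shape == indices.shape + (weight.shape[1],)     # (2, 3, 128)
```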
cc @antinucleon
For this PR, we are just going to do the dense gradient. The sparse gradient will take some work, so we will add it in a later PR.
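For concreteness, here is a minimal sketch of what a dense gradient means, assuming 1D indices; the function name and formulation are illustrative, not the implementation in this PR. The gradient with respect to the table scatter-adds each output-gradient row into the table row selected by the corresponding index, written here as a dense reduction over all lookups:

```python
import tvm
from tvm import te

def embed_grad_dense(table, indices, grad_out):
    # table: (num_embeddings, dim), indices: (n,), grad_out: (n, dim)
    n = indices.shape[0]
    k = te.reduce_axis((0, n), name="k")
    zero = tvm.tir.const(0, grad_out.dtype)
    return te.compute(
        table.shape,
        lambda row, col: te.sum(
            tvm.tir.if_then_else(indices[k] == row, grad_out[k, col], zero),
            axis=k,
        ),
        name="embed_grad_dense",
    )
```

A sparse gradient would instead keep just the `(indices, grad_out)` pairs rather than materializing a full table-shaped tensor.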
""" | ||
s = te.create_schedule([outs[0].op]) | ||
|
||
vec_size = 8 # should autotune this, but we can't with hybrid script |
May I ask why 8? I am just wondering if we could reuse this schedule for the ARM back-end as well.
We could reuse it. I just didn't have a good way to figure out the width of the vector instructions.
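One possible way to derive the vector width from the compilation target instead of hardcoding 8; the ISA checks and cut-offs below are guesses for illustration, not code from this PR:

```python
import tvm

def guess_vec_size(dtype="float32"):
    # Crude heuristic: look for ISA hints in the current target string, map them
    # to a register width in bits, then divide by the element width.
    target = tvm.target.Target.current(allow_none=True)
    target_str = str(target) if target is not None else ""
    if "avx512" in target_str:
        vec_bits = 512
    elif "avx2" in target_str:
        vec_bits = 256
    else:
        vec_bits = 128  # SSE / NEON-class default
    return max(vec_bits // tvm.runtime.DataType(dtype).bits, 1)
```

With `float32` this gives 16, 8, or 4 lanes respectively, which would let the same schedule serve the ARM back-end without the x86-oriented constant.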
This PR adds the embed op and its gradient. Embed is a specialization of take with a 2D lookup table.
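To make the "specialization of take" point concrete, a rough sketch; the name `embed` and this exact formulation are illustrative and may differ from what the PR implements:

```python
from tvm import te, topi

def embed(table, indices):
    # Gathering whole rows of a 2D lookup table along axis 0 is exactly an
    # embedding lookup; the output shape is indices.shape + (embedding_dim,).
    assert len(table.shape) == 2, "embedding table must be 2D"
    return topi.take(table, indices, axis=0)

# Example placeholders (shapes made up for illustration).
table = te.placeholder((10000, 128), name="table", dtype="float32")
idx = te.placeholder((4, 16), name="indices", dtype="int32")
out = embed(table, idx)  # shape (4, 16, 128)
```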
@altanh