[Relay][Training] Add gradient for Crossentropy #3925
Conversation
LGTM except some minor issues.
python/tvm/relay/op/nn/_nn.py (outdated)

```
@@ -717,3 +717,16 @@ def schedule_bitserial_dense(attrs, outputs, target):


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)



```
nit: remove extra blank line (only two are needed)
@vinx13 I have addressed your comment. Can you review again?
lgtm
```
@@ -1621,3 +1621,7 @@ def bitserial_dense(data,
    """
    return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
                                 pack_dtype, out_dtype, unipolar)


def cross_entropy(predictions, targets):
```
This should have a docstring. You should also mention that this is cross-entropy without softmax, as many frameworks equate cross-entropy to cross-entropy from logits
@MarisaKirisame can you react to @SWu's comment? (Also put it in the REGISTER_RELAY_OP section.)
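A minimal sketch of what such a docstring could look like, assuming the `_make.cross_entropy` binding this PR introduces (the wording is illustrative, not the merged text):

```python
def cross_entropy(predictions, targets):
    """Computes cross entropy given predictions and targets.

    Note that this operator does not apply softmax to its input:
    ``predictions`` is expected to already be a probability
    distribution (e.g. the output of ``nn.softmax``), whereas many
    frameworks define cross-entropy directly on logits.

    Parameters
    ----------
    predictions : tvm.relay.Expr
        The predicted probabilities.

    targets : tvm.relay.Expr
        The target distribution.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.cross_entropy(predictions, targets)
```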
src/relay/op/nn/nn.cc (outdated)

```cpp
RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
```
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE) | |
.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE) |
src/relay/op/nn/nn.cc (outdated)

```cpp
      << "y shape=" << y->shape;
  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
      << "CrossEntropy: shapes of x and y is inconsistent, "
      << "x shape=, " << x->shape
```
<< "x shape=, " << x->shape | |
<< "x shape = " << x->shape << ", " |
and can this be done for all of the above instances?
python/tvm/relay/op/nn/_nn.py (outdated)

```python
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    x, y = inputs
    return [-topi.sum(topi.log(x) * y / x.shape[0])]
```
Would it be more efficient and numerically stable to divide by the batch size after the sum?
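A sketch of that rearrangement, assuming the same compute registration as in the hunk above; summing first replaces an elementwise division by the batch size with a single scalar division:

```python
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    x, y = inputs
    # Reduce first, then divide by the batch size once.
    return [-topi.sum(topi.log(x) * y) / x.shape[0]]
```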
@MarisaKirisame The schedule should be injective; can you check whether the CUDA schedule is properly called?
@vinx13 how can I do that? I am not really familiar with TVM's low-level internals.
@MarisaKirisame I will take a look.
@MarisaKirisame I guess this is caused by the use of references, which makes fusion and scheduling difficult. But I couldn't reproduce the error on master; can you try rebasing?
Force-pushed from d112e48 to 9f3c850
python/tvm/relay/op/nn/_nn.py (outdated)

```
@@ -745,3 +745,15 @@ def schedule_bitserial_dense(attrs, outputs, target):


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


reg.register_schedule("nn.cross_entropy", schedule_injective)
```
the schedule actually should be schedule_reduce (in relay.op._reduce)
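That is, roughly the following; the exact symbol name and import path are assumptions based on the reviewer's pointer to `relay.op._reduce`:

```python
# Assumed import: the reviewer points at relay.op._reduce for the
# reduce schedule (cross entropy ends in a sum reduction).
from tvm.relay.op._reduce import _schedule_reduce

reg.register_schedule("nn.cross_entropy", _schedule_reduce)
```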
```python
from tvm.relay.testing import check_grad


def test_crossentropy_grad():
```
nit: test_cross_entropy_grad
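Under the suggested name, the test might look roughly like this (a sketch: the shapes are illustrative and `check_grad` is assumed to run with its default tolerances):

```python
from tvm import relay
from tvm.relay.testing import check_grad


def test_cross_entropy_grad():
    # Illustrative shapes: a batch of 2 distributions over 5 classes.
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    fwd_func = relay.Function([x, y], relay.op.nn.cross_entropy(x, y))
    # check_grad compares the registered gradient against numeric
    # finite differences.
    check_grad(fwd_func)
```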
@MarisaKirisame see my comment; the schedule should be reduce.
ping @MarisaKirisame
Force-pushed from 9f3c850 to 3517d33
@vinx13 sorry, I was pushing the training work to a private branch. I have addressed the issues.
Force-pushed from 3517d33 to 9680232
@MarisaKirisame might be a flaky test case; can you restart the CI?
@vinx13 I will restart it right now. Just FYI, I also got the same error last time.
Force-pushed from 71986be to e884dfc
@MarisaKirisame you can try increasing the rtol of the failing test.
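For example, assuming `check_grad` accepts an `rtol` tolerance for the numeric-vs-symbolic comparison (the value below is illustrative):

```python
# Loosen the relative tolerance of the flaky gradient check.
check_grad(fwd_func, rtol=0.01)
```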
@vinx13 it works now.
@vinx13 I have acted on the comment.
Squashed commits:
* save save redo max test save address comment fix
* address comment
* increase rtol
* address review comment
* master: (21 commits)
  * [Fix][VM] Fix VM invoke with set_params (apache#4079)
  * [QNN] Refactor fixed point multiplication in requantize (apache#4073)
  * Fix match case in Python-side expr functor (apache#4037)
  * Hide symbols from dependent libraries if HIDE_PRIVATE_SYMBOLS is ON. (apache#4041)
  * Add gradient for log-softmax (apache#4069)
  * [DOC] Fix typos in tutorials (apache#4066)
  * dicrease the complexity of CalcDep from exponential to linear (apache#4053)
  * [Relay][AlterOp] Minor refactor. (apache#4064)
  * [Relay][AlterOp] Improving support for broadcast layout alteration. (apache#4040)
  * Add parses support for zeros_like tflite operator (apache#4042)
  * [Bugfix][TF] reset graph after getting tag of savedmodel (apache#4055)
  * [Relay][VM] Add more passes to VMCompiler (apache#4058)
  * [Relay][VM] Add autotvm context when compile (apache#4062)
  * [Bugfix] Fix target host for vm compiler (apache#4057)
  * [Relay][Training] Add gradient for Crossentropy (apache#3925)
  * [llvm] switch to use Align for llvm trunk (apache#4051)
  * [Relay][TopHub] Add switch to disable TopHub download (apache#4015)
  * [Relay][Op] Add instance norm op (apache#4004)
  * [QNN][Relay] Calling Dialect passes from inside Relay Build API. (apache#3971)
  * [RELAY/PASS] Fix the extent for the post_stmt in the loop partition (apache#3734)
  * ...
@vinx13 @junrushao1994 @SWu can you guys help review?