
[Relay][Training] Add gradient for Crossentropy #3925

Merged: 4 commits merged into apache:master on Oct 5, 2019

Conversation

MarisaKirisame
Contributor

@vinx13 @junrushao1994 @SWu can you guys help review?

@vinx13 (Member) left a comment

LGTM except for some minor comments.

python/tvm/relay/testing/__init__.py
@@ -717,3 +717,16 @@ def schedule_bitserial_dense(attrs, outputs, target):


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)



Member

nit: remove extra blank line (only two are needed)

@MarisaKirisame
Contributor Author

@vinx13 I have addressed your comment. Can you review again?

@junrushao (Member) left a comment

lgtm

@@ -1621,3 +1621,7 @@ def bitserial_dense(data,
"""
return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
pack_dtype, out_dtype, unipolar)


def cross_entropy(predictions, targets):
Contributor

This should have a docstring. You should also mention that this is cross-entropy without softmax, as many frameworks equate cross-entropy to cross-entropy from logits

Member

@MarisaKirisame can you react to @SWu's comment (and also put it in the REGISTER_RELAY_OP section)?

@MarisaKirisame
Contributor Author

@vinx13 @tqchen what should its schedule be?



RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
Contributor

Suggested change
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE)

<< "y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
<< "CrossEntropy: shapes of x and y is inconsistent, "
<< "x shape=, " << x->shape
Contributor

Suggested change
<< "x shape=, " << x->shape
<< "x shape = " << x->shape << ", "

Contributor

and can this be done for all of the above instances?

@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(topi.log(x) * y / x.shape[0])]
Contributor

Would it be more efficient and numerically stable to divide by the batch size after the sum?

@vinx13
Member

vinx13 commented Sep 15, 2019

@MarisaKirisame The schedule should be injective; can you check whether the CUDA schedule is properly called?

@MarisaKirisame
Contributor Author

@vinx13 how can I do that? I am not really familiar with TVM's low-level internals.

@vinx13
Member

vinx13 commented Sep 19, 2019

@MarisaKirisame I will take a look

@vinx13
Member

vinx13 commented Sep 19, 2019

@MarisaKirisame I guess this is caused by the use of references, which makes fusion and scheduling difficult. But I didn't reproduce the error on master; can you try rebasing?

@@ -745,3 +745,15 @@ def schedule_bitserial_dense(attrs, outputs, target):


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


reg.register_schedule("nn.cross_entropy", schedule_injective)
Member

the schedule actually should be schedule_reduce (in relay.op._reduce)

from tvm.relay.testing import check_grad


def test_crossentropy_grad():
Member

nit: test_cross_entropy_grad

@MarisaKirisame
Contributor Author

@vinx13 @tqchen it is still the same error. As I don't have a CUDA GPU, I cannot reproduce it.

@vinx13
Member

vinx13 commented Sep 19, 2019

@MarisaKirisame see my comment; the schedule should be reduce.

@vinx13
Member

vinx13 commented Sep 28, 2019

ping @MarisaKirisame

@MarisaKirisame
Contributor Author

@vinx13 sorry, I was pushing training on a private branch. I have addressed the issues.

@MarisaKirisame
Contributor Author

@vinx13 @tqchen can anyone take a look? I didn't change resize in any way. This is blocking training.

@vinx13
Member

vinx13 commented Oct 3, 2019

@MarisaKirisame it might be a flaky test case; can you restart the CI?

vinx13 self-assigned this on Oct 3, 2019
@MarisaKirisame
Contributor Author

@vinx13 I will restart it right now. Just FYI I also got the same error last time.

Commits added: save, redo max test, save, address comment, fix
@vinx13
Member

vinx13 commented Oct 3, 2019

@MarisaKirisame you can try increasing the rtol of the failing test.

@MarisaKirisame
Contributor Author

@vinx13 it works now.

@vinx13
Member

vinx13 commented Oct 4, 2019

@MarisaKirisame
Contributor Author

@vinx13 I have acted on the comment.

vinx13 merged commit 7d71dd8 into apache:master on Oct 5, 2019
MarisaKirisame deleted the crossentropy branch on October 5, 2019 at 01:30
anijain2305 pushed a commit to anijain2305/tvm that referenced this pull request Oct 17, 2019
wweic pushed a commit to neo-ai/tvm that referenced this pull request Oct 18, 2019
petrex added a commit to petrex/tvm that referenced this pull request Oct 29, 2019
* master: merge of 21 commits, including [Relay][Training] Add gradient for Crossentropy (apache#3925)