Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
histogram: implement buckets_v3 #5356
histogram: implement buckets_v3 #5356
Changes from 10 commits
05d8c78
3ea37da
f0815dc
6774bea
fffb21b
5bf44ad
3f4cfd3
9e18664
3e72bde
9d32842
72cb90b
2d5fafe
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I think the reason this was failing is that in graph mode, when it compiles the TensorFlow code into a graph, all of the branches of the
tf.cond()
s are inserted as parts of the graph (even if those branches won't be taken when the graph is actually executed). When doing that, TensorFlow does shape inference, described a little bit here: https://www.tensorflow.org/guide/create_op#shape_functions_in_cBasically, it tries to compute what the resulting shape will be for all of the ops in all of the branches. At the line where we currently do
tf.fill([bucket_count - 1], 0)
, becausebucket_count
is a constant, it is able to determine at compile time that this will result intf.fill([-1], 0)
and the shape function fortf.fill
emits an error, because this would result in an output shape that doesn't make sense.Hence the error (at least, the one that I was seeing when trying this PR locally):
In other words, it doesn't actually mean it's picking the wrong branch of the conditional to execute; it's that during graph compilation, it visits all the branches regardless of what will be ultimately executed.
(Aside: you could argue that when
bucket_count
is a constant 0, we could actually optimize away all thetf.cond
ops and all the branches besides the empty case, at which point we wouldn't hit this error. I'm guessing that the compilation's constant-folding isn't quite smart enough to do that. Also, if you change the test to passbucket_count=tf.Variable(0)
and replace the use ofmax()
withtf.math.maximum()
as mentioned in the other comment, it will also pass, since now that the bucket count is a variable, it doesn't do as much shape inference and can't raise the false positive error.)Ultimately, the best way to fix this is to just ensure that all the branches are still creating the correctly shaped tensor of dimensions
(bucket_count, 3)
, even when we know that branch will never be executed. Here's a patch that gets the test to pass:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the clear explanation and the patch!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should use
tf.math.maximum()
here, rather than Python'smax()
, in order to be compatible withbucket_count
values that aren't native Python numbers (e.g. atf.constant()
instead).Also I'd recommend just adding a conversion up at the top that does
bucket_count = tf.math.maximum(0, bucket_count)
so we don't have to worry about negative bucket counts in any of the later logic.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks!