-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(losses): add Dice loss implementation #19409
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #19409 +/- ##
==========================================
+ Coverage 75.97% 75.98% +0.01%
==========================================
Files 366 366
Lines 40742 40759 +17
Branches 7945 7946 +1
==========================================
+ Hits 30954 30971 +17
Misses 8075 8075
Partials 1713 1713
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR!
keras/losses/losses.py
Outdated
|
||
|
||
@keras_export("keras.losses.dice") | ||
def dice(y_true, y_pred, smooth=1e-6): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove the smooth
argument.
keras/losses/losses.py
Outdated
|
||
intersection = ops.sum(ops.dot(inputs, targets)) | ||
dice = ops.divide( | ||
2.0 * intersection + smooth, ops.sum(y_true) + ops.sum(y_pred) + smooth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead smooth
, use backend.epsilon()
. Only use it for the denominator.
keras/losses/losses.py
Outdated
Returns: | ||
Dice loss value. | ||
""" | ||
y_true = ops.cast(y_true, dtype="float32") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no need to force the use of float32
, you can just use ops.convert_to_tensor(y_true)
keras/losses/losses.py
Outdated
Dice loss value. | ||
""" | ||
y_true = ops.cast(y_true, dtype="float32") | ||
y_pred = ops.cast(y_pred, dtype="float32") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, y_pred.dtype)
@fchollet Thanks for reviewing 👍 |
Shouldn't it be in keras-cv? |
keras/losses/losses.py
Outdated
inputs = ops.reshape(y_true, [-1]) | ||
targets = ops.reshape(y_pred, [-1]) | ||
|
||
intersection = ops.sum(ops.dot(inputs, targets)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Easier to replace dot
with *
here (doesn't change numerics)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
* Refactor dtypes in codebase and add float8_* dtypes * Update comments Fix for JAX export on GPU. (keras-team#19404) Fix formatting in export_lib. (keras-team#19405) `ops/numpy.py`: Support `key` as `list` in `GetItem` (keras-team#19310) When loading a model that contains `GetItem` nodes with multidimensional indices/slices as `key`, the `key` argument is loaded from JSON as a `list`, not a `tuple` (because JSON does not have the distinction). So, treat the `key list` as equivalent to the `key tuple`. Copying is important: otherwise, the later `pop()` will remove the bound slice elements from the op itself. `saving/serialization_lib_test.py`: * Add `test_numpy_get_item_layer()`: test for consistent serialization/deserialization of a model which contains `ops.numpy.GetItem`; feat(losses): add Dice loss implementation (keras-team#19409) * feat(losses): add Dice loss implementation * removed smooth parameter and type casting * adjusted casting and dot operator Update casting Bump the github-actions group with 1 update (keras-team#19412) Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action). Updates `github/codeql-action` from 3.24.6 to 3.24.9 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](github/codeql-action@8a470fd...1b1aada) --- updated-dependencies: - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Fix issue with shared layer deserialization Remove dead code in saving lib (keras-team#19415) Remove unused beta param for silu, use torch op directly (keras-team#19417) The beta param was only accepted on the tensorflow/torch backends and not in the `keras.ops` API, nor was it tested. I think best just to ditch, since no one could be relying on it. Fix print_fn for custom function (keras-team#19419) Add fp8 to `EinsumDense` Add test script
Adds Dice class/function implementation to losses.