
Add Dice Loss #371

Closed
wants to merge 23 commits into from

Conversation

innat
Contributor

@innat innat commented Apr 26, 2022

Close #296

[image attachment]

P.S. Tested on actual training; it works fine so far for the 2D and 3D cases. It now needs some polish.

@qlzh727
Member

qlzh727 commented Apr 27, 2022

Thanks. Will take a look. (Do you still want to keep it as a draft, or make it a formal PR?)

@innat innat marked this pull request as ready for review April 28, 2022 15:26
@innat
Contributor Author

innat commented Apr 28, 2022

@qlzh727 This PR can be reviewed now. Though I will add some more test cases to dice_test.py, feel free to provide some initial comments. Thanks.

@qlzh727 qlzh727 self-requested a review April 28, 2022 17:24
Member

@qlzh727 qlzh727 left a comment

Thanks for the PR, and sorry for the late reply.

(docstring context from `dice.py` under review:)

0.5549314
>>> # Calling with 'sample_weight'.
>>> dice(y_true, y_pred, sample_weight=tf.constant([[0.5, 0.5]])).numpy()
Member

I think we should provide a `sample_weight` that matches `y_true`/`y_pred` without broadcast, e.g. with shape [5], which matches the batch size.
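To illustrate the suggestion, here is a stand-alone sketch (not the PR's actual implementation; `dice_score` and its defaults are illustrative) of a per-sample dice score weighted by a `sample_weight` vector whose shape matches the batch size, so no broadcasting is needed:

```python
import tensorflow as tf

def dice_score(y_true, y_pred, eps=1e-7):
    # One dice score per sample: reduce over the spatial and channel axes.
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
    return (2.0 * intersection + eps) / (union + eps)

y_true = tf.cast(tf.random.uniform((5, 8, 8, 1), maxval=2, dtype=tf.int32), tf.float32)
y_pred = tf.random.uniform((5, 8, 8, 1))

# sample_weight shaped [5], matching the batch size -- applies element-wise.
sample_weight = tf.constant([1.0, 0.5, 0.5, 1.0, 0.0])
per_sample = (1.0 - dice_score(y_true, y_pred)) * sample_weight
loss = tf.reduce_mean(per_sample)
```

With a batch-sized weight vector, each sample's loss is scaled directly and the final reduction is a plain mean, which is what a `[5]`-shaped docstring example would demonstrate.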

keras_cv/losses/dice.py — 4 resolved review threads (outdated)
(docstring context from `dice.py` under review:)

per_sample: If `True`, the loss will be calculated for each sample in
the batch and then averaged. Otherwise the loss will be calculated for
the whole batch. Defaults to `False`.
epsilon: Small float added to the dice score to avoid dividing by zero.
Member

I think this one is probably not needed. We can just use keras.backend.epsilon() directly.

Contributor Author

This one is the smoothing parameter, and sometimes users may want to set their own. WDYT? It can be removed, of course; as the default value I use `backend.epsilon()` anyway.
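For context, a rough sketch of what the discussed design could look like — a user-settable smoothing term that falls back to the backend epsilon (the function name and signature here are illustrative, not the PR's final API):

```python
import tensorflow as tf
from tensorflow.keras import backend

def dice_loss(y_true, y_pred, smooth=None):
    # Fall back to the framework epsilon when the user does not override it.
    if smooth is None:
        smooth = backend.epsilon()
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (union + smooth)
```

A larger smoothing value (e.g. 1.0) is common in segmentation code to stabilize the score on empty masks, whereas `backend.epsilon()` only guards against division by zero — which is the argument for exposing the parameter at all.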

keras_cv/losses/dice.py — 3 resolved review threads
keras_cv/utils/fill_utils.py — 1 resolved review thread (outdated)
@innat innat requested a review from qlzh727 May 18, 2022 15:55
@lucasdavid

lucasdavid commented May 25, 2022

I can see axis=[1,2] in this and some other classes in this repository.
I don't know if this has been discussed somewhere else before, but I would like to suggest that we assume all axes contain positional information, except for the batch and channels axes. This way, every single method/procedure would handle multi-dimensional cases (2D, 3D, etc.) transparently.

This can be achieved by reshaping the n-d signal into a (batch, positions, channels) one, and reducing over positions:

import tensorflow as tf

def some_metric(y_true, y_pred, axis=None):
  if axis is None:
    y_true = squeeze(y_true, dims=3)  # 2d, 3d, 4d, ..., nd -> 3d
    y_pred = squeeze(y_pred, dims=3)  # 2d, 3d, 4d, ..., nd -> 3d
    axis = 1

  # calculate metric result
  # ...
  return tf.reduce_mean(result, axis=axis)

def squeeze(y, dims: int = 3):
  # Validate before touching the tensor.
  if dims not in (2, 3):
    raise ValueError(f'Illegal value for parameter dims=`{dims}`. Can only squeeze '
                     'positional signal, resulting in a tensor with rank 2 or 3.')

  shape = tf.shape(y)
  new_shape = [shape[0], -1]
  if dims == 3:  # keep channels.
    new_shape += [shape[-1]]
  return tf.reshape(y, new_shape)
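As a self-contained restatement of the reshape step above (assuming TensorFlow and `channels_last` layout; the helper name is illustrative), a quick shape check shows both 2D and 3D inputs collapsing to `(batch, positions, channels)`:

```python
import tensorflow as tf

def flatten_positional(y):
    # Collapse all positional axes into one, keeping batch and channels.
    shape = tf.shape(y)
    return tf.reshape(y, [shape[0], -1, shape[-1]])

y2d = tf.zeros((4, 32, 32, 3))     # (batch, h, w, channels)
y3d = tf.zeros((4, 8, 32, 32, 3))  # (batch, d, h, w, channels)

print(flatten_positional(y2d).shape)  # (4, 1024, 3)
print(flatten_positional(y3d).shape)  # (4, 8192, 3)
```

After this normalization, a single reduction over axis 1 covers every spatial rank.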

@innat innat mentioned this pull request Jun 24, 2022
@LukeWood LukeWood added the awaiting review PRs that need to be reviewed label Jul 14, 2022
axis: An optional sequence of `int` specifying the axis to perform reduce
ops for raw dice score. For 2D model, it should be [1,2] or [2,3]
for the `channels_last` or `channels_first` format respectively. And for
3D mdoel, it should be [1,2,3] or [2,3,4] for the `channels_last` or
Contributor

@quantumalaviya quantumalaviya Jul 30, 2022

should be "model" on line 146
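The axis convention quoted from the docstring above can be verified with a small sketch (hedged: assuming TensorFlow; `tf.reduce_sum` stands in for the raw dice-score reduction):

```python
import tensorflow as tf

# 2D model, channels_last (batch, h, w, c): reduce over axes [1, 2].
y_last = tf.ones((2, 16, 16, 1))
per_sample_last = tf.reduce_sum(y_last, axis=[1, 2])    # shape (2, 1)

# 2D model, channels_first (batch, c, h, w): reduce over axes [2, 3].
y_first = tf.ones((2, 1, 16, 16))
per_sample_first = tf.reduce_sum(y_first, axis=[2, 3])  # shape (2, 1)
```

In both layouts the reduction leaves one value per sample and channel, which is what the per-sample dice score needs.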

@LukeWood
Contributor

LukeWood commented Aug 4, 2022

Thanks for your patience @innat; we wanted to merge a Loss first before this. I'll revisit this now that that's done.

@bhack
Contributor

bhack commented Sep 27, 2022

This is the oldest open PR. What is its status?

@LukeWood
Contributor

This is the oldest open PR. What is its status?

To be honest it simply hasn’t been prioritized due to it not being a part of one of our flagship workflows (OD w/ RetinaNet, classification training in a reusable template with SOTA augmentation, and segmentation map prediction with model TBD)

one thing to double check is that the loss reduction matches the construct set forth in the rest of the losses (see smoothL1, focal as examples).

we can revisit this, but at the moment we are focusing effort on delivering those flagship workflows.

@LukeWood
Contributor


That said, @innat, feel free to rebase and re-request a review. It's almost ready, so I'm happy to prioritize this.

@innat
Contributor Author

innat commented Sep 28, 2022

@LukeWood I may not be able to work on it due to a tight schedule. If it's prioritized, someone may need to take it from here. (Same goes for the Jaccard PR.)

@DavidLandup0
Contributor

Can I take over this one and #449?

@DavidLandup0
Contributor

Awesome, thanks! I'll be working on these in the upcoming days then.
What's left is the documentation, test cases and implementing the suggestions from the comments, right?

@innat
Contributor Author

innat commented Oct 12, 2022

Awesome, thanks! I'll be working on these in the upcoming days then. What's left is the documentation, test cases and implementing the suggestions from the comments, right?

Sort of.

@DavidLandup0 DavidLandup0 mentioned this pull request Oct 28, 2022
5 tasks
@DavidLandup0
Contributor

@innat Could you share the training script you mentioned if you still have it?

@innat
Contributor Author

innat commented Oct 28, 2022

I am not sure which one you're referring to. Could you please be more precise? (I'm on leave, so if I have anything related, I will share it in the coming days.)

@DavidLandup0
Contributor

Oh, I meant this:

Ps. Tested on actual training, works fine so far for 2D and 3D cases; now needs some polish.

Sorry to bug you while you're on leave - enjoy your time off! :)

Contributor

@tanzhenyu tanzhenyu left a comment

Thanks for the PR!
A general comment -- is it possible to start with a simple one, say binary dice, and test it with our existing models? That would help a LOT in speeding up the review process.
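In the spirit of this suggestion, a minimal binary dice loss could look roughly like the sketch below (a stand-in for discussion, not the PR's actual code; the class name and defaults are illustrative):

```python
import tensorflow as tf

class BinaryDiceLoss(tf.keras.losses.Loss):
    """Dice loss for binary segmentation masks of shape (batch, h, w, 1)."""

    def __init__(self, smooth=1e-7, name="binary_dice_loss", **kwargs):
        super().__init__(name=name, **kwargs)
        self.smooth = smooth

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Per-sample reduction over the spatial and channel axes;
        # the base class reduction then averages over the batch.
        axes = [1, 2, 3]
        intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
        union = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
        return 1.0 - (2.0 * intersection + self.smooth) / (union + self.smooth)
```

Such a class would plug straight into `model.compile(loss=BinaryDiceLoss(), ...)` for a quick end-to-end check against an existing segmentation model.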

@DavidLandup0
Contributor

Of course! I'll get an example up and running as soon as I finish the test cases for the other PR

@DavidLandup0
Contributor

@tanzhenyu added an example run in the new PR #968 based on your training script for DeepLabV3

Labels
awaiting review PRs that need to be reviewed
Development

Successfully merging this pull request may close these issues.

DICE loss
8 participants