Skip to content

Commit

Permalink
dice typo
Browse files Browse the repository at this point in the history
  • Loading branch information
MengzhangLI committed Jun 25, 2021
1 parent 79cc19e commit 5ba5acd
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/test_models/test_losses/test_tversky_loss.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import torch


def test_dice_lose():
def test_tversky_lose():
from mmseg.models import build_loss

# test dice loss with loss_type = 'multi_class'
# test tversky loss with loss_type = 'multi_class'
loss_cfg = dict(
type='TverskyLoss',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
tversky_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)
tversky_loss(logits, labels)

# test loss with class weights from file
import os
Expand All @@ -30,8 +30,8 @@ def test_dice_lose():
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
tversky_loss = build_loss(loss_cfg)
tversky_loss(logits, labels, ignore_index=None)

np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
Expand All @@ -40,13 +40,13 @@ def test_dice_lose():
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
tversky_loss = build_loss(loss_cfg)
tversky_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

# test dice loss with loss_type = 'binary'
# test tversky loss with loss_type = 'binary'
loss_cfg = dict(
type='TverskyLoss',
smooth=2,
Expand All @@ -56,7 +56,7 @@ def test_dice_lose():
ignore_index=0,
alpha=0.3,
beta=0.7)
dice_loss = build_loss(loss_cfg)
tversky_loss = build_loss(loss_cfg)
logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels)
tversky_loss(logits, labels)

0 comments on commit 5ba5acd

Please sign in to comment.