This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

try to fix

sxjscience committed Oct 14, 2019
1 parent 845ad82 commit 13a35e6
Showing 1 changed file with 9 additions and 9 deletions.

````diff
@@ -256,24 +256,24 @@ When the number of samples for labels are very unbalanced, applying different we
 
 ```{.python .input}
-class WeightedSoftmaxCE(nn.HybridBlock):
+class WeightedSoftmaxCE(nn.Block):
     def __init__(self, sparse_label=True, from_logits=False, **kwargs):
         super(WeightedSoftmaxCE, self).__init__(**kwargs)
         with self.name_scope():
             self.sparse_label = sparse_label
             self.from_logits = from_logits
 
-    def hybrid_forward(self, F, pred, label, class_weight, depth=None):
+    def forward(self, pred, label, class_weight, depth=None):
         if self.sparse_label:
-            label = F.reshape(label, shape=(-1, ))
-            label = F.one_hot(label, depth)
+            label = nd.reshape(label, shape=(-1, ))
+            label = nd.one_hot(label, depth)
         if not self.from_logits:
-            pred = F.log_softmax(pred, -1)
+            pred = nd.log_softmax(pred, -1)
 
-        weight_label = F.broadcast_mul(label, class_weight)
-        loss = -F.sum(pred * weight_label, axis=-1)
+        weight_label = nd.broadcast_mul(label, class_weight)
+        loss = -nd.sum(pred * weight_label, axis=-1)
 
-        # return F.mean(loss, axis=0, exclude=True)
+        # return nd.mean(loss, axis=0, exclude=True)
         return loss
 ```
````
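For readers skimming the diff: a minimal, self-contained sketch of how the post-change `nn.Block` version runs imperatively. The batch size, labels, and weight values are illustrative assumptions, not part of the commit.

```python
# Sketch: exercising the post-change Block-based WeightedSoftmaxCE imperatively.
# Shapes, labels, and weights below are illustrative assumptions.
from mxnet import nd
from mxnet.gluon import nn

class WeightedSoftmaxCE(nn.Block):
    def __init__(self, sparse_label=True, from_logits=False, **kwargs):
        super(WeightedSoftmaxCE, self).__init__(**kwargs)
        with self.name_scope():
            self.sparse_label = sparse_label
            self.from_logits = from_logits

    def forward(self, pred, label, class_weight, depth=None):
        if self.sparse_label:
            label = nd.reshape(label, shape=(-1, ))
            label = nd.one_hot(label, depth)   # depth must be a plain int here
        if not self.from_logits:
            pred = nd.log_softmax(pred, -1)
        weight_label = nd.broadcast_mul(label, class_weight)
        loss = -nd.sum(pred * weight_label, axis=-1)
        return loss

loss_fn = WeightedSoftmaxCE()
pred = nd.random.normal(shape=(4, 3))        # raw scores: 4 samples, 3 classes
label = nd.array([0, 2, 1, 2])               # sparse class indices
class_weight = nd.array([0.3, 0.3, 0.4])     # per-class weights

# As in the second hunk below, depth is now passed as a plain int.
l = loss_fn(pred, label, class_weight, class_weight.shape[0])
print(l.shape)                               # (4,): per-sample weighted loss
```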
```diff
@@ -377,7 +377,7 @@ def calculate_loss(x, y, model, loss, class_weight, penal_coeff):
     if loss_name == 'sce':
         l = loss(pred, y)
     elif loss_name == 'wsce':
-        l = loss(pred, y, class_weight, nd.array(class_weight.shape[0], ctx=x.context))
+        l = loss(pred, y, class_weight, class_weight.shape[0])
 
     # penalty
     diversity_penalty = nd.batch_dot(att, nd.transpose(att, axes=(0, 2, 1))
```
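This call-site change reads as a consequence of the `Block` rewrite above: `nd.one_hot` takes `depth` as a plain Python int, so once the loss runs imperatively the int can be passed directly rather than wrapped in an `NDArray`. That reasoning is my inference from the diff, not stated in the commit; a small sketch with illustrative values:

```python
# Sketch: nd.one_hot takes `depth` as a plain int, which is why the new
# call passes class_weight.shape[0] directly. Values are illustrative.
from mxnet import nd

class_weight = nd.array([0.3, 0.3, 0.4])
depth = class_weight.shape[0]                # plain Python int: 3
print(nd.one_hot(nd.array([0, 2]), depth))   # 2 x 3 one-hot matrix
```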

0 comments on commit 13a35e6
