(This issue is just for my personal self-learning) not flexible enough: paddle.scatter() #43260
Labels
PFCC (Paddle Framework Contributor Club, https://github.com/PaddlePaddle/community/tree/master/pfcc)
status/close (Closed)
type/docs (Documentation issue)
Issue description (Please describe your issue)
"torch code, can be replaced with torch.nn.functional.one_hot"
def one_hot_torch(x, num_classes, on_value=1., off_value=0., device='cuda'):
x = x.long().view(-1, 1)
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
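For reference, the torch.nn.functional.one_hot replacement mentioned above could look roughly like this (a minimal sketch of my own; the rescaling to on_value/off_value and the function name are assumptions, not from the original code):

import torch
import torch.nn.functional as F

def one_hot_via_functional(x, num_classes, on_value=1., off_value=0., device='cuda'):
    # F.one_hot produces int64 zeros/ones; rescale them to off_value/on_value.
    oh = F.one_hot(x.long().view(-1), num_classes).to(device=device, dtype=torch.float32)
    return oh * (on_value - off_value) + off_value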
Paddle version:

import paddle

def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.astype("int64").reshape([-1])
    out = paddle.nn.functional.one_hot(x, num_classes=num_classes)
    # paddle.where expects tensor arguments, so broadcast the scalar on/off values.
    out = paddle.where(out > 0.5, paddle.full_like(out, on_value), paddle.full_like(out, off_value))
    out = paddle.to_tensor(out, place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace())
    return out
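A quick usage check (the label values and class count are hypothetical, only to illustrate the expected output):

labels = paddle.to_tensor([0, 2, 1])
out = one_hot(labels, num_classes=3, on_value=0.9, off_value=0.1)
print(out)  # 3x3 tensor: 0.9 at each label's position, 0.1 elsewhere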
TODO: While studying the above code, I found that, compared with torch.Tensor.scatter_(), paddle.scatter() is not flexible enough. The key problems are how the "index" argument is specified and how repeated values in "index" are handled, especially in GPU mode (threads that simultaneously write to the same tensor element may conflict, so the result may be nondeterministic?).
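To make the flexibility gap concrete, here is a small comparison sketch of my own (not from the original report; the values are made up): torch.Tensor.scatter_ writes individual elements along any axis chosen by dim, whereas paddle.scatter only replaces whole rows selected along axis 0, and with repeated indices and overwrite=True the row that "wins" is order-dependent.

import torch
import paddle

# torch: put 1.0 at a per-row column position, scattering along dim=1
idx = torch.tensor([[1], [3], [0]])
t = torch.zeros(3, 4).scatter_(1, idx, 1.0)

# paddle: index selects whole rows of x to replace along axis 0 only;
# index 2 is repeated, so with overwrite=True the final row 2 depends on
# which update is applied last (a potential race between GPU threads).
index = paddle.to_tensor([2, 0, 2])
updates = paddle.ones([3, 4])
p = paddle.scatter(paddle.zeros([3, 4]), index, updates, overwrite=True)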
The following two PRs may have the same problems (see the sketch after this list):
paddle.index_add
paddle.index_fill
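For context on why repeated indices also matter for paddle.index_add and paddle.index_fill, here is a tiny torch.Tensor.index_add_ illustration of my own (values are made up): repeated entries in index are accumulated rather than overwritten, which is exactly the case that needs careful (atomic) handling in GPU kernels.

import torch

base = torch.zeros(4)
idx = torch.tensor([1, 1, 3])
src = torch.tensor([1.0, 2.0, 5.0])
base.index_add_(0, idx, src)
print(base)  # tensor([0., 3., 0., 5.]) -- index 1 received 1.0 + 2.0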