-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
38 lines (26 loc) · 912 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
class MergeDict(dict):
    """A ``dict`` of dicts that can absorb other dicts of dicts under a name prefix.

    Assigning ``d.merge[name] = other`` copies every ``other[k1][k2]`` into
    ``d[k1]['<name>/<k2>']``, creating ``d[k1]`` when it does not exist yet.
    """

    class Merge(object):
        """Write-only proxy whose item assignment merges into a target dict."""

        def __init__(self, data):
            # Target mapping that merged entries are written into.
            self.dict = data

        def __setitem__(self, name, other):
            """Merge ``other`` into the target, prefixing inner keys with ``name/``.

            Args:
                name: prefix placed before each inner key as ``'{name}/{k2}'``.
                other: mapping of ``k1 -> mapping of k2 -> value``.

            Raises:
                AssertionError: if a prefixed key already exists under the same
                    ``k1`` in the target, or two inner keys of ``other[k1]``
                    format to the same prefixed key. Validation happens before
                    anything is written, so a rejected merge leaves the target
                    unchanged.
            """
            # Stage everything first so a duplicate cannot leave a half-applied
            # merge behind. Raise explicitly rather than using `assert`, which
            # is stripped under `python -O` and would let duplicates overwrite
            # existing entries silently.
            staged = {}
            for k1 in other:
                existing = self.dict.get(k1, {})
                entries = staged.setdefault(k1, {})
                for k2 in other[k1]:
                    k_new = '{}/{}'.format(name, k2)
                    if k_new in existing or k_new in entries:
                        raise AssertionError(
                            'duplicate key {!r} under {!r}'.format(k_new, k1))
                    entries[k_new] = other[k1][k2]

            # Apply: create missing outer entries (even for empty inner
            # mappings) and add the prefixed keys in their original order.
            for k1, entries in staged.items():
                self.dict.setdefault(k1, {}).update(entries)

    @property
    def merge(self):
        """Return a :class:`Merge` proxy that writes into this dict."""
        return self.Merge(self)
def take_until_token(seq, token):
    """Return the part of ``seq`` that precedes the first ``token``.

    When ``token`` does not occur in ``seq``, ``seq`` itself is returned
    unchanged (the same object, not a copy).
    """
    if token not in seq:
        return seq
    cut = seq.index(token)
    return seq[:cut]
def label_smoothing(input, smoothing, dim=2):
    """Mix ``input`` with a uniform distribution over the classes along ``dim``.

    Computes ``input * (1 - smoothing) + smoothing / K`` with
    ``K = input.size(dim)``.

    Args:
        input: tensor of targets; presumably one-hot or probability vectors
            along ``dim`` (not checked here).
        smoothing: weight given to the uniform component; ``0`` returns
            values equal to ``input``.
        dim: axis that holds the classes. Defaults to ``2``, the axis that
            was previously hard-coded, so existing callers are unaffected.
            Pass ``-1`` for inputs of other ranks whose classes are last.

    Returns:
        A new tensor with the same shape as ``input``.
    """
    num_classes = input.size(dim)
    return input * (1 - smoothing) + smoothing / num_classes
def one_hot(input, num_classes):
    """Encode integer class indices as a float one-hot tensor.

    Args:
        input: integer tensor of class indices (any shape, including 0-d).
            Bool inputs are treated as indices 0/1.
        num_classes: size of the trailing one-hot axis.

    Returns:
        Tensor of shape ``(*input.shape, num_classes)``, dtype ``torch.float``,
        on ``input.device``. Negative indices in ``[-num_classes, -1]`` wrap
        around as in Python indexing, matching the earlier
        ``torch.eye(num_classes)[input]`` implementation.

    Raises:
        IndexError: if ``input`` has a floating-point or complex dtype.
        Indices outside ``[-num_classes, num_classes)`` are rejected by
        ``scatter_``.
    """
    # Indexing a float tensor used to raise IndexError; keep rejecting it
    # rather than silently truncating through .long().
    if input.is_floating_point() or input.is_complex():
        raise IndexError(
            'one_hot expects integer class indices, got {}'.format(input.dtype))
    # scatter_ needs int64 indices. Converting also stops uint8/bool inputs
    # from being interpreted as boolean masks, as plain indexing would.
    index = input.long()
    # Preserve the wrap-around that eye(...)[input] gave negative indices.
    index = torch.where(index < 0, index + num_classes, index)
    # Scatter into a zero tensor instead of indexing a num_classes x
    # num_classes identity matrix, avoiding an O(num_classes**2) allocation.
    out = torch.zeros(*input.shape, num_classes,
                      dtype=torch.float, device=input.device)
    return out.scatter_(-1, index.unsqueeze(-1), 1.0)