-
Notifications
You must be signed in to change notification settings - Fork 7
/
tree_crf.py
50 lines (42 loc) · 1.74 KB
/
tree_crf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch_model_utils as tmu
from torch import nn
from torch.distributions.utils import lazy_property
class TreeCRFVanilla(nn.Module):
    """Reference (loop-based) tree CRF over labelled binary span trees.

    A labelled tree is a binary bracketing of the sequence in which every
    span ``(i, j)`` (``j`` inclusive) of the tree carries one label; its score
    is the sum of ``log_potentials[:, i, j, label]`` over its spans. This is
    a slow, straightforward implementation: only ``partition`` is real,
    ``entropy`` is a zero placeholder and ``argmax`` raises.

    Args:
        log_potentials: span/label scores indexed as
            ``[batch, start, end, label]``; only entries with
            ``start <= end`` are read.
        lengths: optional per-example sequence lengths. When ``None`` every
            example is scored over the full ``max_len``. Presumably a
            LongTensor of shape ``[batch]`` -- TODO confirm against callers.
    """

    def __init__(self, log_potentials, lengths=None):
        # Initialise nn.Module bookkeeping (_parameters, _modules, ...);
        # without it repr(), .to() and .parameters() raise AttributeError.
        super().__init__()
        self.log_potentials = log_potentials
        self.lengths = lengths

    @lazy_property
    def entropy(self):
        """Placeholder entropy: a zero tensor of shape ``[batch]``.

        NOTE(review): this is NOT the true tree entropy; it always returns
        zeros (on the potentials' device).
        """
        batch_size = self.log_potentials.size(0)
        return torch.zeros(batch_size, device=self.log_potentials.device)

    @lazy_property
    def partition(self):
        """Log partition function log Z, computed with the inside algorithm.

        ``beta[:, i, j, l]`` is the log-sum of the scores of all labelled
        subtrees spanning ``i..j`` whose root label is ``l``::

            beta[i, i] = log_potentials[i, i]
            beta[i, j] = log_potentials[i, j]
                         + logsumexp_{k, l1, l2} (beta[i, k, l1] + beta[k + 1, j, l2])

        Returns:
            Tensor of shape ``[batch]`` with log Z for each example.
        """
        device = self.log_potentials.device
        batch_size = self.log_potentials.size(0)
        max_len = self.log_potentials.size(1)

        # Chart filled in place; entries with i > j stay zero and are never read.
        beta = torch.zeros_like(self.log_potentials)

        # Leaves: single-token spans (i, i).
        diag = torch.arange(max_len, device=device)
        beta[:, diag, diag] = self.log_potentials[:, diag, diag]

        # Wider spans in order of increasing width d = j - i, so every
        # sub-span is already in the chart when it is needed.
        for d in range(1, max_len):
            for i in range(max_len - d):
                j = i + d
                # For split m = 0 .. d - 1: left child (i, i + m),
                # right child (i + m + 1, j); broadcast over label pairs.
                left = beta[:, i, i:j].unsqueeze(-1)  # [B, d, L, 1]
                right = beta[:, i + 1:j + 1, j].unsqueeze(-2)  # [B, d, 1, L]
                children = (left + right).flatten(start_dim=1)
                # The parent potential does not depend on the child labels,
                # so all splits and child labellings collapse to one value.
                inside = torch.logsumexp(children, -1, keepdim=True)  # [B, 1]
                beta[:, i, j] = self.log_potentials[:, i, j] + inside

        if self.lengths is None:
            root = beta[:, 0, max_len - 1]
        else:
            # Presumably selects beta[b, 0, lengths[b] - 1] for each b
            # (root span of example b) -- TODO confirm tmu semantics.
            root = tmu.batch_index_select(beta[:, 0], self.lengths - 1)
        return torch.logsumexp(root, -1)

    @lazy_property
    def argmax(self):
        """Not implemented for this reference implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('slow argmax not implemented!')