From b1ad484bff1f7515f306e31310f9e1aad4fef6fc Mon Sep 17 00:00:00 2001 From: Tom Grek Date: Sun, 28 Apr 2019 16:55:32 -0700 Subject: [PATCH] Remove phased rotate --- nn/rotate.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/nn/rotate.py b/nn/rotate.py index 33e576d..734b994 100644 --- a/nn/rotate.py +++ b/nn/rotate.py @@ -36,10 +36,7 @@ def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma, a=-self.embedding_range.item(), b=self.embedding_range.item()) - if model_name == 'pRotatE': - self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]])) - - if model_name not in ['TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE']: + if model_name not in ['ComplEx', 'RotatE']: raise ValueError('model {} not supported'.format(model_name)) def forward(self, sample, mode='single'): @@ -111,8 +108,7 @@ def forward(self, sample, mode='single'): model_func = { 'ComplEx': self.ComplEx, - 'RotatE': self.RotatE, - 'pRotatE': self.pRotatE + 'RotatE': self.RotatE } if self.model_name in model_func: @@ -166,23 +162,6 @@ def RotatE(self, head, relation, tail, mode): score = self.gamma.item() - score.sum(dim=2) return score - def pRotatE(self, head, relation, tail, mode): - - phase_head = head/(self.embedding_range.item()/math.pi) - phase_relation = relation/(self.embedding_range.item()/math.pi) - phase_tail = tail/(self.embedding_range.item()/math.pi) - - if mode == 'head-batch': - score = phase_head + (phase_relation - phase_tail) - else: - score = (phase_head + phase_relation) - phase_tail - - score = torch.sin(score) - score = torch.abs(score) - - score = self.gamma.item() - score.sum(dim = 2) * self.modulus - return score - @staticmethod def train_step(model, optimizer, train_iterator, args): model.train()