Skip to content

Commit

Permalink
Merge pull request #241 from dice-group/DualE
Browse files Browse the repository at this point in the history
DualE implemented within the dice framework.
  • Loading branch information
Demirrr authored Mar 27, 2024
2 parents 2fa2dcc + e7d33c1 commit 2d2b945
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 2 deletions.
1 change: 1 addition & 0 deletions dicee/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .clifford import Keci, KeciBase, CMult, DeCaL # noqa
from .pykeen_models import * # noqa
from .function_space import * # noqa
from .dualE import DualE
130 changes: 130 additions & 0 deletions dicee/models/dualE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import torch
from .base_model import BaseKGE



class DualE(BaseKGE):
def __init__(self, args):
super().__init__(args)
self.name = 'DualE'
self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
self.num_ent = self.num_entities


#Calculate the Dual Hamiltonian product
def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3):

h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3
h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3
h_1=a_0*c_1+a_1*c_0+a_2*c_3-a_3*c_2
h1_1=a_0*d_1+b_0*c_1+a_1*d_0+b_1*c_0+a_2*d_3+b_2*c_3-a_3*d_2-b_3*c_2
h_2=a_0*c_2-a_1*c_3+a_2*c_0+a_3*c_1
h1_2=a_0*d_2+b_0*c_2-a_1*d_3-b_1*c_3+a_2*d_0+b_2*c_0+a_3*d_1+b_3*c_1
h_3=a_0*c_3+a_1*c_2-a_2*c_1+a_3*c_0
h1_3=a_0*d_3+b_0*c_3+a_1*d_2+b_1*c_2-a_2*d_1-b_2*c_1+a_3*d_0+b_3*c_0

return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3)

#Normalization of relationship embedding
def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2
denominator_1 = torch.sqrt(denominator_0)
#denominator_2 = torch.sqrt(r_5 ** 2 + r_6 ** 2 + r_7 ** 2 + r_8 ** 2)
deno_cross = r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4

r_5 = r_5 - deno_cross / denominator_0 * r_1
r_6 = r_6 - deno_cross / denominator_0 * r_2
r_7 = r_7 - deno_cross / denominator_0 * r_3
r_8 = r_8 - deno_cross / denominator_0 * r_4

r_1 = r_1 / denominator_1
r_2 = r_2 / denominator_1
r_3 = r_3 / denominator_1
r_4 = r_4 / denominator_1
#r_5 = r_5 / denominator_2
#r_6 = r_6 / denominator_2
#r_7 = r_7 / denominator_2
#r_8 = r_8 / denominator_2
return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8

#Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity
def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)


score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
+ o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)

return -torch.sum(score_r, -1)

def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)


score_r = torch.mm(o_1, e_1_t) + torch.mm(o_2 ,e_2_t) + torch.mm(o_3, e_3_t) + torch.mm(o_4, e_4_t)\
+ torch.mm(o_5, e_5_t) + torch.mm(o_6, e_6_t) + torch.mm(o_7, e_7_t) +torch.mm( o_8 , e_8_t)

return -score_r


def forward_triples(self, idx_triple):

head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple)


e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8)

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8)

score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )


return score



def forward_k_vs_all(self,x):

# (1) Retrieve embeddings & Apply Dropout & Normalization.
head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)

e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_ent_emb, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(self.entity_embeddings.weight, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = self.T(e_1_t), self.T(e_2_t), self.T(e_3_t),\
self.T(e_4_t), self.T(e_5_t), self.T(e_6_t), self.T(e_7_t), self.T(e_8_t)

score = self.kvsall_score(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )


return score

def T(self, x):

return x.transpose(1, 0)




2 changes: 1 addition & 1 deletion dicee/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_default_arguments(description=None):
parser.add_argument("--model", type=str,
default="Keci",
choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
"OMult", "Shallom", "DistMult", "TransE", "DeCaL",
"OMult", "Shallom", "DistMult", "TransE", "DualE",
"BytE",
"Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP",
"Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE",
Expand Down
5 changes: 4 additions & 1 deletion dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import datetime
from typing import Tuple, List
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL,\
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL, DualE,\
ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult
from .models.pykeen_models import PykeenKGE
from .models.transformers import BytE
Expand Down Expand Up @@ -421,6 +421,9 @@ def intialize_model(args: dict,verbose=0) -> Tuple[object, str]:
elif model_name == 'DeCaL':
model =DeCaL(args=args)
form_of_labelling = 'EntityPrediction'
elif model_name == 'DualE':
model =DualE(args=args)
form_of_labelling = 'EntityPrediction'
else:
raise ValueError(f"--model_name: {model_name} is not found.")
return model, form_of_labelling
Expand Down
36 changes: 36 additions & 0 deletions tests/test_regression_DualE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dicee.executer import Execute
import pytest
from dicee.config import Namespace

class TestRegressionClifford:
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all(self):
args = Namespace()
args.model = 'DualE'
args.scoring_technique = 'KvsAll'
args.optim = 'Adam'
args.dataset_dir = 'KGs/UMLS'
args.num_epochs = 32
args.batch_size = 1024
args.lr = 0.1
args.embedding_dim = 32
args.eval_model = 'train_val_test'
dualE_result = Execute(args).start()

args = Namespace()
args.model = 'DeCaL'
args.scoring_technique = 'KvsAll'
args.optim = 'Adam'
args.p = 0
args.q = 1
args.r = 1
args.dataset_dir = 'KGs/UMLS'
args.num_epochs = 32
args.batch_size = 1024
args.lr = 0.1
args.embedding_dim = 32
args.eval_model = 'train_val_test'
decal_result = Execute(args).start()

assert decal_result["Train"]["MRR"] > dualE_result["Train"]["MRR"]
assert decal_result["Test"]["MRR"] > dualE_result["Test"]["MRR"]

0 comments on commit 2d2b945

Please sign in to comment.