Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DualE implemented within the dice framework. #241

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dicee/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .clifford import Keci, KeciBase, CMult, DeCaL # noqa
from .pykeen_models import * # noqa
from .function_space import * # noqa
from .dualE import DualE
130 changes: 130 additions & 0 deletions dicee/models/dualE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import torch
from .base_model import BaseKGE



class DualE(BaseKGE):
def __init__(self, args):
super().__init__(args)
self.name = 'DualE'
self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
self.num_ent = self.num_entities


#Calculate the Dual Hamiltonian product
def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3):

h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3
h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3
h_1=a_0*c_1+a_1*c_0+a_2*c_3-a_3*c_2
h1_1=a_0*d_1+b_0*c_1+a_1*d_0+b_1*c_0+a_2*d_3+b_2*c_3-a_3*d_2-b_3*c_2
h_2=a_0*c_2-a_1*c_3+a_2*c_0+a_3*c_1
h1_2=a_0*d_2+b_0*c_2-a_1*d_3-b_1*c_3+a_2*d_0+b_2*c_0+a_3*d_1+b_3*c_1
h_3=a_0*c_3+a_1*c_2-a_2*c_1+a_3*c_0
h1_3=a_0*d_3+b_0*c_3+a_1*d_2+b_1*c_2-a_2*d_1-b_2*c_1+a_3*d_0+b_3*c_0

return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3)

#Normalization of relationship embedding
def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2
denominator_1 = torch.sqrt(denominator_0)
#denominator_2 = torch.sqrt(r_5 ** 2 + r_6 ** 2 + r_7 ** 2 + r_8 ** 2)
deno_cross = r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4

r_5 = r_5 - deno_cross / denominator_0 * r_1
r_6 = r_6 - deno_cross / denominator_0 * r_2
r_7 = r_7 - deno_cross / denominator_0 * r_3
r_8 = r_8 - deno_cross / denominator_0 * r_4

r_1 = r_1 / denominator_1
r_2 = r_2 / denominator_1
r_3 = r_3 / denominator_1
r_4 = r_4 / denominator_1
#r_5 = r_5 / denominator_2
#r_6 = r_6 / denominator_2
#r_7 = r_7 / denominator_2
#r_8 = r_8 / denominator_2
return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8

#Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity
def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)


score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
+ o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)

return -torch.sum(score_r, -1)

def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)


score_r = torch.mm(o_1, e_1_t) + torch.mm(o_2 ,e_2_t) + torch.mm(o_3, e_3_t) + torch.mm(o_4, e_4_t)\
+ torch.mm(o_5, e_5_t) + torch.mm(o_6, e_6_t) + torch.mm(o_7, e_7_t) +torch.mm( o_8 , e_8_t)

return -score_r


def forward_triples(self, idx_triple):

head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple)


e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8)

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8)

score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )


return score



def forward_k_vs_all(self,x):

# (1) Retrieve embeddings & Apply Dropout & Normalization.
head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)

e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_ent_emb, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(self.entity_embeddings.weight, 8)

e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = self.T(e_1_t), self.T(e_2_t), self.T(e_3_t),\
self.T(e_4_t), self.T(e_5_t), self.T(e_6_t), self.T(e_7_t), self.T(e_8_t)

score = self.kvsall_score(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )


return score

def T(self, x):

return x.transpose(1, 0)




2 changes: 1 addition & 1 deletion dicee/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_default_arguments(description=None):
parser.add_argument("--model", type=str,
default="Keci",
choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
"OMult", "Shallom", "DistMult", "TransE", "DeCaL",
"OMult", "Shallom", "DistMult", "TransE", "DualE",
"BytE",
"Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP",
"Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE",
Expand Down
5 changes: 4 additions & 1 deletion dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import datetime
from typing import Tuple, List
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL,\
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL, DualE,\
ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult
from .models.pykeen_models import PykeenKGE
from .models.transformers import BytE
Expand Down Expand Up @@ -421,6 +421,9 @@ def intialize_model(args: dict,verbose=0) -> Tuple[object, str]:
elif model_name == 'DeCaL':
model =DeCaL(args=args)
form_of_labelling = 'EntityPrediction'
elif model_name == 'DualE':
model =DualE(args=args)
form_of_labelling = 'EntityPrediction'
else:
raise ValueError(f"--model_name: {model_name} is not found.")
return model, form_of_labelling
Expand Down
36 changes: 36 additions & 0 deletions tests/test_regression_DualE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dicee.executer import Execute
import pytest
from dicee.config import Namespace

class TestRegressionClifford:
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all(self):
args = Namespace()
args.model = 'DualE'
args.scoring_technique = 'KvsAll'
args.optim = 'Adam'
args.dataset_dir = 'KGs/UMLS'
args.num_epochs = 32
args.batch_size = 1024
args.lr = 0.1
args.embedding_dim = 32
args.eval_model = 'train_val_test'
dualE_result = Execute(args).start()

args = Namespace()
args.model = 'DeCaL'
args.scoring_technique = 'KvsAll'
args.optim = 'Adam'
args.p = 0
args.q = 1
args.r = 1
args.dataset_dir = 'KGs/UMLS'
args.num_epochs = 32
args.batch_size = 1024
args.lr = 0.1
args.embedding_dim = 32
args.eval_model = 'train_val_test'
decal_result = Execute(args).start()

assert decal_result["Train"]["MRR"] > dualE_result["Train"]["MRR"]
assert decal_result["Test"]["MRR"] > dualE_result["Test"]["MRR"]
Loading