Skip to content

Commit

Permalink
Merge pull request #283 from dice-group/refactor
Browse files Browse the repository at this point in the history
Refactoring before new release
  • Loading branch information
Demirrr authored Dec 1, 2024
2 parents 787b535 + 4326478 commit 3332aec
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 30 deletions.
10 changes: 0 additions & 10 deletions dicee/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,6 @@ def vocab_preparation(self, dataset) -> None:
else:
self.ee_vocab = dataset.ee_vocab.result()

"""
if isinstance(dataset.constraints, tuple):
self.domain_constraints_per_rel, self.range_constraints_per_rel = dataset.constraints
else:
try:
self.domain_constraints_per_rel, self.range_constraints_per_rel = dataset.constraints.result()
except RuntimeError:
print('Domain constraint exception occurred')
"""

self.num_entities = dataset.num_entities
self.num_relations = dataset.num_relations
self.func_triple_to_bpe_representation = dataset.func_triple_to_bpe_representation
Expand Down
7 changes: 5 additions & 2 deletions dicee/static_funcs_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
the filtered missing tail entity rank
:param model:
:param triple_idx:
:param num_entities:
:param er_vocab:
:param re_vocab:
:param info:
:param batch_size:
:param chunk_size:
Expand Down Expand Up @@ -127,7 +130,7 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
), dim=1)

# Predict scores for missing tails
preds_tails = model.forward_triples(x_tails)
preds_tails = model(x_tails)
preds_tails = preds_tails.view(batch_size_current, chunk_size_current)
predictions_tails[:, chunk_start:chunk_end] = preds_tails
del x_tails
Expand All @@ -140,7 +143,7 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
), dim=1)

# Predict scores for missing heads
preds_heads = model.forward_triples(x_heads)
preds_heads = model(x_heads)
preds_heads = preds_heads.view(batch_size_current, chunk_size_current)
predictions_heads[:, chunk_start:chunk_end] = preds_heads
del x_heads
Expand Down
6 changes: 3 additions & 3 deletions dicee/trainer/model_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def fit(self, *args, **kwargs):
timeout=0,
worker_init_fn=None,
persistent_workers=False)
if batch_rt is not None:
expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")
#if batch_rt is not None:
# expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
# print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")

# () Number of batches to reach a single epoch.
num_of_batches = len(train_dataloader)
Expand Down
28 changes: 13 additions & 15 deletions tests/test_regression_model_paralelisim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@
from dicee.config import Namespace
import os
import torch
class TestRegressionModelParallel:
class TestRegressionTensorParallel:
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all(self):
# @TODO:
"""
if torch.cuda.is_available():
args = Namespace()
args.model = 'Keci'
args.trainer = "MP"
args.trainer = "TP"
args.scoring_technique = "KvsAll" # 1vsAll, or AllvsAll, or NegSample
args.dataset_dir = "KGs/UMLS"
args.path_to_store_single_run = "Keci_UMLS"
args.num_epochs = 100
# CD: TP currently doesn't work with path_to_store_single_run and eval.
#args.path_to_store_single_run = "Keci_UMLS"
args.optim="Adopt"
args.num_epochs = 10
args.embedding_dim = 32
args.batch_size = 1024
args.batch_size = 32
args.lr=0.1
reports = Execute(args).start()
assert reports["Train"]["MRR"] >= 0.990
assert reports["Test"]["MRR"] >= 0.810
write_csv_from_model_parallel(path="Keci_UMLS")
assert os.path.exists("Keci_UMLS/entity_embeddings.csv")
assert os.path.exists("Keci_UMLS/relation_embeddings.csv")
assert reports["Train"]["MRR"] >= 0.60
assert reports["Test"]["MRR"] >= 0.58
#assert os.path.exists("Keci_UMLS/entity_embeddings.csv")
#assert os.path.exists("Keci_UMLS/relation_embeddings.csv")

os.system(f'rm -rf Keci_UMLS')
"""
#os.system(f'rm -rf Keci_UMLS')

0 comments on commit 3332aec

Please sign in to comment.