diff --git a/dicee/dataset_classes.py b/dicee/dataset_classes.py index 2da9d834..33733d50 100644 --- a/dicee/dataset_classes.py +++ b/dicee/dataset_classes.py @@ -242,8 +242,10 @@ class OnevsAllDataset(torch.utils.data.Dataset): torch.utils.data.Dataset """ - def __init__(self, train_set_idx: np.memmap, entity_idxs): + def __init__(self, train_set_idx: np.ndarray, entity_idxs): super().__init__() + assert isinstance(train_set_idx, np.memmap) + assert isinstance(train_set_idx, np.ndarray) assert len(train_set_idx) > 0 self.train_data = train_set_idx @@ -298,6 +300,7 @@ def __init__(self, train_set_idx: np.ndarray, entity_idxs, relation_idxs, form, label_smoothing_rate: float = 0.0): super().__init__() assert len(train_set_idx) > 0 + assert isinstance(train_set_idx, np.memmap) assert isinstance(train_set_idx, np.ndarray) self.train_data = None self.train_target = None @@ -394,6 +397,7 @@ def __init__(self, train_set_idx: np.ndarray, entity_idxs, relation_idxs, label_smoothing_rate=0.0): super().__init__() assert len(train_set_idx) > 0 + assert isinstance(train_set_idx, np.memmap) assert isinstance(train_set_idx, np.ndarray) self.train_data = None self.train_target = None diff --git a/dicee/knowledge_graph.py b/dicee/knowledge_graph.py index 87a873e3..a21586ab 100644 --- a/dicee/knowledge_graph.py +++ b/dicee/knowledge_graph.py @@ -82,14 +82,6 @@ def __init__(self, dataset_dir: str = None, LoadSaveToDisk(kg=self).save() else: LoadSaveToDisk(kg=self).load() - train_set_shape=self.train_set.shape - train_set_dtype=self.train_set.dtype - - fp = np.memmap(self.path_for_serialization + '/memory_map_train_set.npy', dtype=train_set_dtype, mode='w+', shape=train_set_shape) - fp[:] = self.train_set[:] - self.train_set=fp - del fp - assert len(self.train_set) > 0, "Training set is empty" self._describe() diff --git a/dicee/trainer/dice_trainer.py b/dicee/trainer/dice_trainer.py index 99134cbf..bc357cb3 100644 --- a/dicee/trainer/dice_trainer.py +++ b/dicee/trainer/dice_trainer.py @@ -1,9 +1,7 @@ import lightning as pl import gc - from typing import Union - from dicee.models.base_model import BaseKGE from dicee.static_funcs import select_model from dicee.callbacks import ASWA, Eval, KronE, PrintCallback, AccumulateEpochLossCallback, Perturb @@ -17,7 +15,7 @@ import copy from typing import List, Tuple from ..knowledge_graph import KG - +import numpy as np def initialize_trainer(args, callbacks): if args.trainer == 'torchCPUTrainer': @@ -199,6 +197,15 @@ def initialize_dataloader(self, dataset: torch.utils.data.Dataset) -> torch.util @timeit def initialize_dataset(self, dataset: KG, form_of_labelling) -> torch.utils.data.Dataset: print('Initializing Dataset...', end='\t') + train_set_shape=dataset.train_set.shape + train_set_dtype=dataset.train_set.dtype + + fp = np.memmap(dataset.path_for_serialization + '/memory_map_train_set.npy', dtype=train_set_dtype, mode='w+', shape=train_set_shape) + fp[:] = dataset.train_set[:] + dataset.train_set=fp + del fp + + train_dataset = construct_dataset(train_set=dataset.train_set, valid_set=dataset.valid_set, test_set=dataset.test_set,