Skip to content

Commit

Permalink
WIP: replacing nump arrays with memory maps
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Oct 22, 2024
1 parent e6ea7bb commit 2eb5db4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
6 changes: 5 additions & 1 deletion dicee/dataset_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,10 @@ class OnevsAllDataset(torch.utils.data.Dataset):
torch.utils.data.Dataset
"""

def __init__(self, train_set_idx: np.memmap, entity_idxs):
def __init__(self, train_set_idx: np.ndarray, entity_idxs):
super().__init__()
assert isinstance(train_set_idx, np.memmap)

assert isinstance(train_set_idx, np.ndarray)
assert len(train_set_idx) > 0
self.train_data = train_set_idx
Expand Down Expand Up @@ -298,6 +300,7 @@ def __init__(self, train_set_idx: np.ndarray, entity_idxs, relation_idxs, form,
label_smoothing_rate: float = 0.0):
super().__init__()
assert len(train_set_idx) > 0
assert isinstance(train_set_idx, np.memmap)
assert isinstance(train_set_idx, np.ndarray)
self.train_data = None
self.train_target = None
Expand Down Expand Up @@ -394,6 +397,7 @@ def __init__(self, train_set_idx: np.ndarray, entity_idxs, relation_idxs,
label_smoothing_rate=0.0):
super().__init__()
assert len(train_set_idx) > 0
assert isinstance(train_set_idx, np.memmap)
assert isinstance(train_set_idx, np.ndarray)
self.train_data = None
self.train_target = None
Expand Down
8 changes: 0 additions & 8 deletions dicee/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,6 @@ def __init__(self, dataset_dir: str = None,
LoadSaveToDisk(kg=self).save()
else:
LoadSaveToDisk(kg=self).load()
train_set_shape=self.train_set.shape
train_set_dtype=self.train_set.dtype

fp = np.memmap(self.path_for_serialization + '/memory_map_train_set.npy', dtype=train_set_dtype, mode='w+', shape=train_set_shape)
fp[:] = self.train_set[:]
self.train_set=fp
del fp

assert len(self.train_set) > 0, "Training set is empty"
self._describe()

Expand Down
13 changes: 10 additions & 3 deletions dicee/trainer/dice_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import lightning as pl

import gc

from typing import Union

from dicee.models.base_model import BaseKGE
from dicee.static_funcs import select_model
from dicee.callbacks import ASWA, Eval, KronE, PrintCallback, AccumulateEpochLossCallback, Perturb
Expand All @@ -17,7 +15,7 @@
import copy
from typing import List, Tuple
from ..knowledge_graph import KG

import numpy as np

def initialize_trainer(args, callbacks):
if args.trainer == 'torchCPUTrainer':
Expand Down Expand Up @@ -199,6 +197,15 @@ def initialize_dataloader(self, dataset: torch.utils.data.Dataset) -> torch.util
@timeit
def initialize_dataset(self, dataset: KG, form_of_labelling) -> torch.utils.data.Dataset:
print('Initializing Dataset...', end='\t')
train_set_shape=dataset.train_set.shape
train_set_dtype=dataset.train_set.dtype

fp = np.memmap(dataset.path_for_serialization + '/memory_map_train_set.npy', dtype=train_set_dtype, mode='w+', shape=train_set_shape)
fp[:] = dataset.train_set[:]
dataset.train_set=fp
del fp


train_dataset = construct_dataset(train_set=dataset.train_set,
valid_set=dataset.valid_set,
test_set=dataset.test_set,
Expand Down

0 comments on commit 2eb5db4

Please sign in to comment.