Onevsall with memory map
Demirrr committed Oct 22, 2024
1 parent 6e070e4 · commit e6ea7bb
Showing 4 changed files with 14 additions and 15 deletions.
12 changes: 5 additions & 7 deletions dicee/dataset_classes.py
@@ -242,23 +242,21 @@ class OnevsAllDataset(torch.utils.data.Dataset):
     torch.utils.data.Dataset
     """

-    def __init__(self, train_set_idx: np.ndarray, entity_idxs):
+    def __init__(self, train_set_idx: np.memmap, entity_idxs):
         super().__init__()
-        assert isinstance(train_set_idx, np.ndarray)
         assert len(train_set_idx) > 0
-        self.train_data = torch.LongTensor(train_set_idx)
+        self.train_data = train_set_idx
         self.target_dim = len(entity_idxs)
         self.collate_fn = None

     def __len__(self):
         return len(self.train_data)

     def __getitem__(self, idx):
         y_vec = torch.zeros(self.target_dim)
-        y_vec[self.train_data[idx, 2]] = 1
-        return self.train_data[idx, :2], y_vec
-
-
+        triple = torch.from_numpy(self.train_data[idx].copy()).long()
+        y_vec[triple[2]] = 1
+        return triple[:2], y_vec
 class KvsAll(torch.utils.data.Dataset):
     """ Creates a dataset for KvsAll training by inheriting from torch.utils.data.Dataset.
     Let D denote a dataset for KvsAll training and be defined as D := {(x,y)_i}_{i=1}^N, where
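A minimal sketch of how the updated class is consumed, assuming the dicee import path below; the file name, toy triples, and entity-to-index mapping are illustrative and not part of the commit:

    import numpy as np
    import torch
    from dicee.dataset_classes import OnevsAllDataset

    entity_idxs = {'a': 0, 'b': 1, 'c': 2}                    # toy entity-to-index map
    triples = np.array([[0, 0, 1], [1, 0, 2]], dtype=np.int64)

    # Write the index triples once, then reopen them as a read-only memory map.
    fp = np.memmap('toy_train_set.npy', dtype=np.int64, mode='w+', shape=triples.shape)
    fp[:] = triples[:]
    fp.flush()
    train_set = np.memmap('toy_train_set.npy', dtype=np.int64, mode='r', shape=triples.shape)

    dataset = OnevsAllDataset(train_set, entity_idxs)
    x, y = dataset[0]
    assert x.tolist() == [0, 0] and y[1] == 1                 # (head, relation), one-hot tail

The .copy() inside __getitem__ is what makes read-only maps safe here: torch.from_numpy on a slice of a read-only memmap would share non-writable memory (PyTorch warns that writing through such a tensor is undefined), so each accessed row is first copied into a small writable array before conversion.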
9 changes: 9 additions & 0 deletions dicee/knowledge_graph.py
@@ -3,6 +3,8 @@
 import sys
 import pandas as pd
 import polars as pl
+import numpy as np
+from .read_preprocess_save_load_kg.util import load_numpy_ndarray
 class KG:
     """ Knowledge Graph """

@@ -80,6 +82,13 @@ def __init__(self, dataset_dir: str = None,
             LoadSaveToDisk(kg=self).save()
         else:
             LoadSaveToDisk(kg=self).load()
+        train_set_shape = self.train_set.shape
+        train_set_dtype = self.train_set.dtype
+
+        fp = np.memmap(self.path_for_serialization + '/memory_map_train_set.npy', dtype=train_set_dtype, mode='w+', shape=train_set_shape)
+        fp[:] = self.train_set[:]
+        self.train_set = fp
+        del fp

         assert len(self.train_set) > 0, "Training set is empty"
         self._describe()
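The new block copies the preprocessed training triples into a disk-backed memory map and rebinds self.train_set to it, so the large index array is paged in from disk on demand instead of being held in RAM. A standalone sketch of the same pattern, with a stand-in directory for self.path_for_serialization:

    import numpy as np

    def to_memmap(train_set: np.ndarray, serialization_dir: str) -> np.memmap:
        """Copy an in-memory triple array into a disk-backed memory map."""
        fp = np.memmap(serialization_dir + '/memory_map_train_set.npy',
                       dtype=train_set.dtype, mode='w+', shape=train_set.shape)
        fp[:] = train_set[:]   # one-time copy; later reads are paged in lazily
        return fp

    triples = np.random.randint(0, 100, size=(1000, 3))
    train_set = to_memmap(triples, '.')
    print(type(train_set), train_set.shape)   # <class 'numpy.memmap'> (1000, 3)

One caveat worth knowing: despite the .npy suffix, np.memmap writes headerless raw bytes, so the dtype and shape recorded in train_set_dtype and train_set_shape must be supplied again by anything that reopens the file.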
1 change: 0 additions & 1 deletion dicee/read_preprocess_save_load_kg/read_from_disk.py
@@ -31,7 +31,6 @@ def start(self) -> None:
                                                backend=self.kg.backend)
             if self.kg.add_noise_rate:
                 self.add_noisy_triples_into_training()
-
             self.kg.raw_valid_set = None
             self.kg.raw_test_set = None
         elif self.kg.sparql_endpoint is not None:
7 changes: 0 additions & 7 deletions dicee/read_preprocess_save_load_kg/util.py
@@ -113,7 +113,6 @@ def read_with_polars(data_path, read_only_few: int = None, sample_triples_ratio:
     """ Load and Preprocess via Polars """
     print(f'*** Reading {data_path} with Polars ***')
     # (1) Load the data.
-    print('Reading with polars.read_csv with sep **t** ...')
     df = polars.read_csv(data_path,
                          has_header=False,
                          low_memory=False,
@@ -122,12 +121,6 @@ def read_with_polars(data_path, read_only_few: int = None, sample_triples_ratio:
                          dtypes=[polars.String],
                          new_columns=['subject', 'relation', 'object'],
                          separator=" ")  # \s+ doesn't work for polars
-    # parquet usage deprecated.
-    #if read_only_few is None:
-    #    df = polars.read_parquet(data_path, use_pyarrow=True)
-    #else:
-    #    df = polars.read_parquet(data_path, n_rows=read_only_few)
-    #
     # (2) Sample from (1).
     if sample_triples_ratio:
         print(f'Subsampling {sample_triples_ratio} of input data {df.shape}...')
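For context, the read path this hunk keeps parses whitespace-separated triples into a three-column frame. A hedged sketch with a made-up file path (note that recent Polars releases rename the dtypes parameter to schema_overrides):

    import polars

    # Hypothetical triple file: one "subject relation object" line per row.
    df = polars.read_csv('KGs/example/train.txt',
                         has_header=False,
                         low_memory=False,
                         dtypes=[polars.String],
                         new_columns=['subject', 'relation', 'object'],
                         separator=' ')   # literal single space; a regex such as \s+ is unsupported
    print(df.shape)   # (number_of_triples, 3)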
