Skip to content

Commit

Permalink
Added parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
Graphite4 authored and swojciechowski committed Aug 16, 2022
1 parent ad6d4f2 commit d1263ea
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions weles/evaluation/Evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
from os import path, listdir
from hashlib import md5
import inspect
from joblib import Parallel, delayed

VERBOSE_COLUMNS = 80


class Evaluator():
def __init__(self, datasets, protocol=(1, 5, None), store=None):
def __init__(self, datasets, protocol=(1, 5, None), store=None, parallel_jobs=1):
self.datasets = datasets
self.protocol = protocol
self.store = store
self.parallel_jobs = parallel_jobs

self.m, self.k, self.random_state = self.protocol
if self.random_state is None:
self.store = None




# Check storage
if store is not None:
Expand All @@ -35,6 +44,30 @@ def __init__(self, datasets, protocol=(1, 5, None), store=None):
else:
print("Store is none")

def process_one_dataset(self, dataset_name, dataset_id, verbose=False):
    """
    Run the evaluation protocol for a single dataset.

    The dataset is split with RepeatedStratifiedKFold using ``self.k``
    splits, ``self.m`` repeats and ``self.random_state``. For every fold
    the test labels are written to ``self.true_values[dataset_id,
    fold_id]`` and, for every classifier in ``self.clfs``, its
    predictions are written to ``self.predictions[dataset_id, clf_id,
    fold_id]``.

    When ``self.store`` is set and the prediction key (with ".npy"
    appended) is found in ``self.stored``, predictions are loaded from
    the store instead of being recomputed. Otherwise a clone of the
    classifier is fitted, ``self.callbacks.after_clf_fit`` is invoked,
    and the fresh predictions are saved to the store if one is set.

    :param dataset_name: Key of the dataset in ``self.datasets``.
    :param dataset_id: Row index of this dataset in the result arrays;
        also used as the progress bar position.
    :param verbose: Show a per-dataset progress bar when True.
    """
    X, y = self.datasets[dataset_name]

    splitter = RepeatedStratifiedKFold(
        n_splits=self.k,
        n_repeats=self.m,
        random_state=self.random_state,
    )
    progress = tqdm(
        splitter.split(X, y),
        disable=not verbose,
        leave=True,
        position=dataset_id,
        colour="CYAN",
    )
    progress.set_description(dataset_name.ljust(20))

    for fold_id, (train, test) in enumerate(progress):
        gt_key = self._storage_key_gt(X, y, fold_id)
        self.true_values[dataset_id, fold_id] = y[test]

        for clf_id, clf_name in enumerate(self.clfs):
            pred_key = self._storage_key_pred(gt_key, self.clfs[clf_name])

            # Reuse stored predictions when they are available.
            if self.store is not None and pred_key + ".npy" in self.stored:
                self.predictions[dataset_id, clf_id, fold_id] = np.load(
                    "%s/%s.npy" % (self.store, pred_key)
                )
                continue

            # Otherwise fit a fresh clone so the base estimator in
            # self.clfs is not the object being fitted.
            model = clone(self.clfs[clf_name])
            model.fit(X[train], y[train])
            self.callbacks.after_clf_fit(dataset_name, fold_id, clf_name, model)
            fold_pred = model.predict(X[test])

            if self.store is not None:
                # np.save appends the ".npy" suffix to the file name.
                np.save("%s/%s" % (self.store, pred_key), fold_pred)

            self.predictions[dataset_id, clf_id, fold_id] = fold_pred

def process(self, clfs, verbose=False):
"""
This function is used to process declared evaluation protocol
Expand All @@ -48,39 +81,19 @@ def process(self, clfs, verbose=False):
self.clfs = clfs

# Establish protocol
self.m, self.k, self.random_state = self.protocol
if self.random_state is None:
self.store = None

skf = RepeatedStratifiedKFold(n_splits=self.k, n_repeats=self.m,
random_state=self.random_state)
self.predictions = np.zeros([len(self.datasets), len(self.clfs),
self.m * self.k], dtype=object)
self.true_values = np.zeros([len(self.datasets), self.m * self.k],
dtype=object)

bar = tqdm(self.datasets, disable=not verbose)

# Iterate over datasets
for dataset_id, dataset_name in enumerate(bar):
bar.set_description(dataset_name.ljust(20))
X, y = self.datasets[dataset_name]
for fold_id, (train, test) in enumerate(skf.split(X, y)):
str_gt = self._storage_key_gt(X, y, fold_id)
self.true_values[dataset_id, fold_id] = y[test]

for clf_id, clf_name in enumerate(self.clfs):
str_clf = self._storage_key_pred(str_gt,
self.clfs[clf_name])

if self.store is not None and str_clf+".npy" in self.stored:
y_pred = np.load("%s/%s.npy" % (self.store, str_clf))
else:
clf = clone(self.clfs[clf_name])
clf.fit(X[train], y[train])
y_pred = clf.predict(X[test])
if self.store is not None:
np.save("%s/%s" % (self.store, str_clf), y_pred)
self.predictions[dataset_id, clf_id, fold_id] = y_pred
Parallel(n_jobs=self.parallel_jobs, require='sharedmem')(
delayed(self.process_one_dataset)
(dataset_name, dataset_id, verbose)
for dataset_id, dataset_name in enumerate(self.datasets.keys())
)

return self

Expand Down

0 comments on commit d1263ea

Please sign in to comment.